Commit 8285a149 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit...

add CFG denoiser implementation for DDIM, PLMS and UniPC (this is the commit when you can run both old and new implementations to compare them)
parent 2d8e4a65
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, shared
from modules import sd_samplers_compvis, sd_samplers_kdiffusion, sd_samplers_timesteps, shared

# imports for functions that previously were here and are used by other modules
from modules.sd_samplers_common import samples_to_image_grid, sample_to_image  # noqa: F401
@@ -6,6 +6,7 @@ from modules.sd_samplers_common import samples_to_image_grid, sample_to_image #
all_samplers = [
    *sd_samplers_kdiffusion.samplers_data_k_diffusion,
    *sd_samplers_compvis.samplers_data_compvis,
    *sd_samplers_timesteps.samplers_data_timesteps,
]
all_samplers_map = {x.name: x for x in all_samplers}

+18 −32
Original line number Diff line number Diff line
@@ -39,7 +39,7 @@ class CFGDenoiser(torch.nn.Module):
    negative prompt.
    """

    def __init__(self, model):
    def __init__(self, model, sampler):
        super().__init__()
        self.inner_model = model
        self.mask = None
@@ -48,6 +48,7 @@ class CFGDenoiser(torch.nn.Module):
        self.step = 0
        self.image_cfg_scale = None
        self.padded_cond_uncond = False
        self.sampler = sampler

    def combine_denoised(self, x_out, conds_list, uncond, cond_scale):
        denoised_uncond = x_out[-uncond.shape[0]:]
@@ -65,6 +66,9 @@ class CFGDenoiser(torch.nn.Module):

        return denoised

    def get_pred_x0(self, x_in, x_out, sigma):
        return x_out

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException
@@ -78,6 +82,9 @@ class CFGDenoiser(torch.nn.Module):

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        if self.mask is not None:
            x = self.init_latent * self.mask + self.nmask * x

        batch_size = len(conds_list)
        repeats = [len(conds_list[i]) for i in range(batch_size)]

@@ -170,11 +177,6 @@ class CFGDenoiser(torch.nn.Module):

        devices.test_for_nans(x_out, "unet")

        if opts.live_preview_content == "Prompt":
            sd_samplers_common.store_latent(torch.cat([x_out[i:i+1] for i in denoised_image_indexes]))
        elif opts.live_preview_content == "Negative prompt":
            sd_samplers_common.store_latent(x_out[-uncond.shape[0]:])

        if is_edit_model:
            denoised = self.combine_denoised_for_edit_model(x_out, cond_scale)
        elif skip_uncond:
@@ -182,8 +184,16 @@ class CFGDenoiser(torch.nn.Module):
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        if self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised
        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

        if opts.live_preview_content == "Prompt":
            preview = self.sampler.last_latent
        elif opts.live_preview_content == "Negative prompt":
            preview = self.get_pred_x0(x_in[-uncond.shape[0]:], x_out[-uncond.shape[0]:], sigma)
        else:
            preview = self.get_pred_x0(torch.cat([x_in[i:i+1] for i in denoised_image_indexes]), torch.cat([denoised[i:i+1] for i in denoised_image_indexes]), sigma)

        sd_samplers_common.store_latent(preview)

        after_cfg_callback_params = AfterCFGCallbackParams(denoised, state.sampling_step, state.sampling_steps)
        cfg_after_cfg_callback(after_cfg_callback_params)
@@ -192,27 +202,3 @@ class CFGDenoiser(torch.nn.Module):
        self.step += 1
        return denoised

class TorchHijack:
    def __init__(self, sampler_noises):
        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
        # implementation.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        if item == 'randn_like':
            return self.randn_like

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise

        return devices.randn_like(x)
+139 −1
Original line number Diff line number Diff line
from collections import namedtuple
import inspect
from collections import namedtuple, deque
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
import k_diffusion.sampling

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

@@ -127,3 +129,139 @@ def replace_torchsde_browinan():


replace_torchsde_browinan()


class TorchHijack:
    def __init__(self, sampler_noises):
        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
        # implementation.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        if item == 'randn_like':
            return self.randn_like

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise

        return devices.randn_like(x)


class Sampler:
    def __init__(self, funcname):
        self.funcname = funcname
        self.func = funcname
        self.extra_params = []
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.eta_option_field = 'eta_ancestral'
        self.eta_infotext_field = 'Eta'

        self.conditioning_key = shared.sd_model.model.conditioning_key

        self.model_wrap = None
        self.model_wrap_cfg = None

    def callback_state(self, d):
        step = d['i']

        if self.stop_at is not None and step > self.stop_at:
            raise InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        return p.steps

    def initialize(self, p) -> dict:
        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.model_wrap_cfg.step = 0
        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.eta = p.eta if p.eta is not None else getattr(opts, self.eta_option_field, 0.0)
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        if 'eta' in inspect.signature(self.func).parameters:
            if self.eta != 1.0:
                p.extra_generation_params[self.eta_infotext_field] = self.eta

            extra_params_kwargs['eta'] = self.eta

        if len(self.extra_params) > 0:
            s_churn = getattr(opts, 's_churn', p.s_churn)
            s_tmin = getattr(opts, 's_tmin', p.s_tmin)
            s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
            s_noise = getattr(opts, 's_noise', p.s_noise)

            if s_churn != self.s_churn:
                extra_params_kwargs['s_churn'] = s_churn
                p.s_churn = s_churn
                p.extra_generation_params['Sigma churn'] = s_churn
            if s_tmin != self.s_tmin:
                extra_params_kwargs['s_tmin'] = s_tmin
                p.s_tmin = s_tmin
                p.extra_generation_params['Sigma tmin'] = s_tmin
            if s_tmax != self.s_tmax:
                extra_params_kwargs['s_tmax'] = s_tmax
                p.s_tmax = s_tmax
                p.extra_generation_params['Sigma tmax'] = s_tmax
            if s_noise != self.s_noise:
                extra_params_kwargs['s_noise'] = s_noise
                p.s_noise = s_noise
                p.extra_generation_params['Sigma noise'] = s_noise

        return extra_params_kwargs

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)


+14 −138
Original line number Diff line number Diff line
@@ -4,8 +4,7 @@ import inspect
import k_diffusion.sampling
from modules import devices, sd_samplers_common, sd_samplers_extra, sd_samplers_cfg_denoiser

from modules.processing import StableDiffusionProcessing
from modules.shared import opts, state
from modules.shared import opts
import modules.shared as shared

samplers_k_diffusion = [
@@ -54,133 +53,17 @@ k_diffusion_scheduler = {
}


class TorchHijack:
    def __init__(self, sampler_noises):
        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
        # implementation.
        self.sampler_noises = deque(sampler_noises)

    def __getattr__(self, item):
        if item == 'randn_like':
            return self.randn_like

        if hasattr(torch, item):
            return getattr(torch, item)

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{item}'")

    def randn_like(self, x):
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise
class KDiffusionSampler(sd_samplers_common.Sampler):
    def __init__(self, funcname, sd_model):

        return devices.randn_like(x)
        super().__init__(funcname)

        self.extra_params = sampler_extra_params.get(funcname, [])
        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)

class KDiffusionSampler:
    def __init__(self, funcname, sd_model):
        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser

        self.model_wrap = denoiser(sd_model, quantize=shared.opts.enable_quantization)
        self.funcname = funcname
        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
        self.extra_params = sampler_extra_params.get(funcname, [])
        self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap)
        self.sampler_noises = None
        self.stop_at = None
        self.eta = None
        self.config = None  # set by the function calling the constructor
        self.last_latent = None
        self.s_min_uncond = None

        # NOTE: These are also defined in the StableDiffusionProcessing class.
        # They should have been here to begin with but we're going to
        # leave that class __init__ signature alone.
        self.s_churn = 0.0
        self.s_tmin = 0.0
        self.s_tmax = float('inf')
        self.s_noise = 1.0

        self.conditioning_key = sd_model.model.conditioning_key

    def callback_state(self, d):
        step = d['i']
        latent = d["denoised"]
        if opts.live_preview_content == "Combined":
            sd_samplers_common.store_latent(latent)
        self.last_latent = latent

        if self.stop_at is not None and step > self.stop_at:
            raise sd_samplers_common.InterruptedException

        state.sampling_step = step
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        state.sampling_steps = steps
        state.sampling_step = 0

        try:
            return func()
        except RecursionError:
            print(
                'Encountered RecursionError during sampling, returning last latent. '
                'rho >5 with a polyexponential scheduler may cause this error. '
                'You should try to use a smaller rho value instead.'
            )
            return self.last_latent
        except sd_samplers_common.InterruptedException:
            return self.last_latent

    def number_of_needed_noises(self, p):
        return p.steps

    def initialize(self, p: StableDiffusionProcessing):
        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.model_wrap_cfg.step = 0
        self.model_wrap_cfg.image_cfg_scale = getattr(p, 'image_cfg_scale', None)
        self.eta = p.eta if p.eta is not None else opts.eta_ancestral
        self.s_min_uncond = getattr(p, 's_min_uncond', 0.0)

        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        if 'eta' in inspect.signature(self.func).parameters:
            if self.eta != 1.0:
                p.extra_generation_params["Eta"] = self.eta

            extra_params_kwargs['eta'] = self.eta

        if len(self.extra_params) > 0:
            s_churn = getattr(opts, 's_churn', p.s_churn)
            s_tmin = getattr(opts, 's_tmin', p.s_tmin)
            s_tmax = getattr(opts, 's_tmax', p.s_tmax) or self.s_tmax # 0 = inf
            s_noise = getattr(opts, 's_noise', p.s_noise)

            if s_churn != self.s_churn:
                extra_params_kwargs['s_churn'] = s_churn
                p.s_churn = s_churn
                p.extra_generation_params['Sigma churn'] = s_churn
            if s_tmin != self.s_tmin:
                extra_params_kwargs['s_tmin'] = s_tmin
                p.s_tmin = s_tmin
                p.extra_generation_params['Sigma tmin'] = s_tmin
            if s_tmax != self.s_tmax:
                extra_params_kwargs['s_tmax'] = s_tmax
                p.s_tmax = s_tmax
                p.extra_generation_params['Sigma tmax'] = s_tmax
            if s_noise != self.s_noise:
                extra_params_kwargs['s_noise'] = s_noise
                p.s_noise = s_noise
                p.extra_generation_params['Sigma noise'] = s_noise

        return extra_params_kwargs
        self.model_wrap_cfg = sd_samplers_cfg_denoiser.CFGDenoiser(self.model_wrap, self)

    def get_sigmas(self, p, steps):
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
@@ -232,22 +115,12 @@ class KDiffusionSampler:

        return sigmas

    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)

        sigmas = self.get_sigmas(p, steps)

        sigma_sched = sigmas[steps - t_enc - 1:]

        xi = x + noise * sigma_sched[0]

        extra_params_kwargs = self.initialize(p)
@@ -296,12 +169,14 @@ class KDiffusionSampler:
        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

        if 'n' in parameters:
            extra_params_kwargs['n'] = steps

        if 'sigma_min' in parameters:
            extra_params_kwargs['sigma_min'] = self.model_wrap.sigmas[0].item()
            extra_params_kwargs['sigma_max'] = self.model_wrap.sigmas[-1].item()
            if 'n' in parameters:
                extra_params_kwargs['n'] = steps
        else:

        if 'sigmas' in parameters:
            extra_params_kwargs['sigmas'] = sigmas

        if self.config.options.get('brownian_noise', False):
@@ -322,3 +197,4 @@ class KDiffusionSampler:

        return samples

+147 −0
Original line number Diff line number Diff line
import torch
import inspect
from modules import devices, sd_samplers_common, sd_samplers_timesteps_impl
from modules.sd_samplers_cfg_denoiser import CFGDenoiser

from modules.shared import opts
import modules.shared as shared

samplers_timesteps = [
    ('k_DDIM', sd_samplers_timesteps_impl.ddim, ['k_ddim'], {}),
    ('k_PLMS', sd_samplers_timesteps_impl.plms, ['k_plms'], {}),
    ('k_UniPC', sd_samplers_timesteps_impl.unipc, ['k_unipc'], {}),
]


samplers_data_timesteps = [
    sd_samplers_common.SamplerData(label, lambda model, funcname=funcname: CompVisSampler(funcname, model), aliases, options)
    for label, funcname, aliases, options in samplers_timesteps
]


class CompVisTimestepsDenoiser(torch.nn.Module):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inner_model = model

    def forward(self, input, timesteps, **kwargs):
        return self.inner_model.apply_model(input, timesteps, **kwargs)


class CompVisTimestepsVDenoiser(torch.nn.Module):
    def __init__(self, model, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.inner_model = model

    def predict_eps_from_z_and_v(self, x_t, t, v):
        return self.inner_model.sqrt_alphas_cumprod[t.to(torch.int), None, None, None] * v + self.inner_model.sqrt_one_minus_alphas_cumprod[t.to(torch.int), None, None, None] * x_t

    def forward(self, input, timesteps, **kwargs):
        model_output = self.inner_model.apply_model(input, timesteps, **kwargs)
        e_t = self.predict_eps_from_z_and_v(input, timesteps, model_output)
        return e_t


class CFGDenoiserTimesteps(CFGDenoiser):

    def __init__(self, model, sampler):
        super().__init__(model, sampler)

        self.alphas = model.inner_model.alphas_cumprod

    def get_pred_x0(self, x_in, x_out, sigma):
        ts = int(sigma.item())

        s_in = x_in.new_ones([x_in.shape[0]])
        a_t = self.alphas[ts].item() * s_in
        sqrt_one_minus_at = (1 - a_t).sqrt()

        pred_x0 = (x_in - sqrt_one_minus_at * x_out) / a_t.sqrt()

        return pred_x0


class CompVisSampler(sd_samplers_common.Sampler):
    def __init__(self, funcname, sd_model):
        super().__init__(funcname)

        self.eta_option_field = 'eta_ddim'
        self.eta_infotext_field = 'Eta DDIM'

        denoiser = CompVisTimestepsVDenoiser if sd_model.parameterization == "v" else CompVisTimestepsDenoiser
        self.model_wrap = denoiser(sd_model)
        self.model_wrap_cfg = CFGDenoiserTimesteps(self.model_wrap, self)

    def get_timesteps(self, p, steps):
        discard_next_to_last_sigma = self.config is not None and self.config.options.get('discard_next_to_last_sigma', False)
        if opts.always_discard_next_to_last_sigma and not discard_next_to_last_sigma:
            discard_next_to_last_sigma = True
            p.extra_generation_params["Discard penultimate sigma"] = True

        steps += 1 if discard_next_to_last_sigma else 0

        timesteps = torch.clip(torch.asarray(list(range(0, 1000, 1000 // steps)), device=devices.device) + 1, 0, 999)

        return timesteps

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)

        timesteps = self.get_timesteps(p, steps)
        timesteps_sched = timesteps[:t_enc]

        alphas_cumprod = shared.sd_model.alphas_cumprod
        sqrt_alpha_cumprod = torch.sqrt(alphas_cumprod[timesteps[t_enc]])
        sqrt_one_minus_alpha_cumprod = torch.sqrt(1 - alphas_cumprod[timesteps[t_enc]])

        xi = x * sqrt_alpha_cumprod + noise * sqrt_one_minus_alpha_cumprod

        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

        if 'timesteps' in parameters:
            extra_params_kwargs['timesteps'] = timesteps_sched
        if 'is_img2img' in parameters:
            extra_params_kwargs['is_img2img'] = True

        self.model_wrap_cfg.init_latent = x
        self.last_latent = x
        extra_args = {
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }

        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        steps = steps or p.steps
        timesteps = self.get_timesteps(p, steps)

        extra_params_kwargs = self.initialize(p)
        parameters = inspect.signature(self.func).parameters

        if 'timesteps' in parameters:
            extra_params_kwargs['timesteps'] = timesteps

        self.last_latent = x
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True

        return samples
Loading