Commit d64b4516 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

added support for automatically installing latest k-diffusion

added eta parameter to parameters output for generated images
split eta settings into ancestral and ddim (because they have different default values)
parent 9be0d1b8
Loading
Loading
Loading
Loading
+7 −0
Original line number Diff line number Diff line
@@ -113,6 +113,13 @@ if not skip_torch_cuda_test:
if not is_installed("k_diffusion.sampling"):
    run_pip(f"install {k_diffusion_package}", "k-diffusion")

if not check_run_python("import k_diffusion; import inspect; assert 'eta' in inspect.signature(k_diffusion.sampling.sample_euler_ancestral).parameters"):
    print(f"k-diffusion does not have 'eta' parameter; reinstalling latest version")
    try:
        run_pip(f"install --upgrade --force-reinstall {k_diffusion_package}", "k-diffusion")
    except RuntimeError as e:
        print(str(e))

if not is_installed("gfpgan"):
    run_pip(f"install {gfpgan_package}", "gfpgan")

+5 −4
Original line number Diff line number Diff line
@@ -49,7 +49,7 @@ def apply_color_correction(correction, image):


class StableDiffusionProcessing:
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None):
    def __init__(self, sd_model=None, outpath_samples=None, outpath_grids=None, prompt="", styles=None, seed=-1, subseed=-1, subseed_strength=0, seed_resize_from_h=-1, seed_resize_from_w=-1, seed_enable_extras=True, sampler_index=0, batch_size=1, n_iter=1, steps=50, cfg_scale=7.0, width=512, height=512, restore_faces=False, tiling=False, do_not_save_samples=False, do_not_save_grid=False, extra_generation_params=None, overlay_images=None, negative_prompt=None, eta=None):
        self.sd_model = sd_model
        self.outpath_samples: str = outpath_samples
        self.outpath_grids: str = outpath_grids
@@ -75,11 +75,11 @@ class StableDiffusionProcessing:
        self.do_not_save_grid: bool = do_not_save_grid
        self.extra_generation_params: dict = extra_generation_params or {}
        self.overlay_images = overlay_images
        self.eta = eta
        self.paste_to = None
        self.color_corrections = None
        self.denoising_strength: float = 0

        self.eta = opts.eta
        self.ddim_discretize = opts.ddim_discretize
        self.s_churn = opts.s_churn
        self.s_tmin = opts.s_tmin
@@ -271,6 +271,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w == 0 or p.seed_resize_from_h == 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
        "Denoising strength": getattr(p, 'denoising_strength', None),
        "Eta": (None if p.sampler.eta == p.sampler.default_eta else p.sampler.eta),
    }

    generation_params.update(p.extra_generation_params)
+44 −39
Original line number Diff line number Diff line
@@ -40,10 +40,8 @@ samplers_for_img2img = [x for x in samplers if x.name != 'PLMS']

sampler_extra_params = {
    'sample_euler': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_euler_ancestral': ['eta'],
    'sample_heun': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_2': ['s_churn', 's_tmin', 's_tmax', 's_noise'],
    'sample_dpm_2_ancestral': ['eta'],
}

def setup_img2img_steps(p, steps=None):
@@ -101,6 +99,8 @@ class VanillaStableDiffusionSampler:
        self.init_latent = None
        self.sampler_noises = None
        self.step = 0
        self.eta = None
        self.default_eta = 0.0

    def number_of_needed_noises(self, p):
        return 0
@@ -123,20 +123,29 @@ class VanillaStableDiffusionSampler:
        self.step += 1
        return res

    def initialize(self, p):
        self.eta = p.eta or opts.eta_ddim

        for fieldname in ['p_sample_ddim', 'p_sample_plms']:
            if hasattr(self.sampler, fieldname):
                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)

        self.mask = p.mask if hasattr(p, 'mask') else None
        self.nmask = p.nmask if hasattr(p, 'nmask') else None

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
        steps, t_enc = setup_img2img_steps(p, steps)

        # existing code fails with cetain step counts, like 9
        try:
            self.sampler.make_schedule(ddim_num_steps=steps,  ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
            self.sampler.make_schedule(ddim_num_steps=steps,  ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)
        except Exception:
            self.sampler.make_schedule(ddim_num_steps=steps+1,ddim_eta=p.ddim_eta, ddim_discretize=p.ddim_discretize, verbose=False)
            self.sampler.make_schedule(ddim_num_steps=steps+1, ddim_eta=self.eta, ddim_discretize=p.ddim_discretize, verbose=False)

        x1 = self.sampler.stochastic_encode(x, torch.tensor([t_enc] * int(x.shape[0])).to(shared.device), noise=noise)

        self.sampler.p_sample_ddim = self.p_sample_ddim_hook
        self.mask = p.mask if hasattr(p, 'mask') else None
        self.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.initialize(p)

        self.init_latent = x
        self.step = 0

@@ -145,11 +154,8 @@ class VanillaStableDiffusionSampler:
        return samples

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
        for fieldname in ['p_sample_ddim', 'p_sample_plms']:
            if hasattr(self.sampler, fieldname):
                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
        self.mask = None
        self.nmask = None
        self.initialize(p)

        self.init_latent = None
        self.step = 0

@@ -157,9 +163,9 @@ class VanillaStableDiffusionSampler:

        # existing code fails with cetin step counts, like 9
        try:
            samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta)
            samples_ddim, _ = self.sampler.sample(S=steps, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)
        except Exception:
            samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=p.eta)
            samples_ddim, _ = self.sampler.sample(S=steps+1, conditioning=conditioning, batch_size=int(x.shape[0]), shape=x[0].shape, verbose=False, unconditional_guidance_scale=p.cfg_scale, unconditional_conditioning=unconditional_conditioning, x_T=x, eta=self.eta)

        return samples_ddim

@@ -237,6 +243,8 @@ class KDiffusionSampler:
        self.sampler_noises = None
        self.sampler_noise_index = 0
        self.stop_at = None
        self.eta = None
        self.default_eta = 1.0

    def callback_state(self, d):
        store_latent(d["denoised"])
@@ -255,22 +263,12 @@ class KDiffusionSampler:
        self.sampler_noise_index += 1
        return res

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
        steps, t_enc = setup_img2img_steps(p, steps)

        sigmas = self.model_wrap.get_sigmas(steps)

        noise = noise * sigmas[steps - t_enc - 1]

        xi = x + noise

        sigma_sched = sigmas[steps - t_enc - 1:]

    def initialize(self, p):
        self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
        self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
        self.model_wrap_cfg.init_latent = x
        self.model_wrap.step = 0
        self.sampler_noise_index = 0
        self.eta = p.eta or opts.eta_ancestral

        if hasattr(k_diffusion.sampling, 'trange'):
            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)
@@ -283,6 +281,25 @@ class KDiffusionSampler:
            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        if 'eta' in inspect.signature(self.func).parameters:
            extra_params_kwargs['eta'] = self.eta

        return extra_params_kwargs

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None):
        steps, t_enc = setup_img2img_steps(p, steps)

        sigmas = self.model_wrap.get_sigmas(steps)

        noise = noise * sigmas[steps - t_enc - 1]
        xi = x + noise

        extra_params_kwargs = self.initialize(p)

        sigma_sched = sigmas[steps - t_enc - 1:]

        self.model_wrap_cfg.init_latent = x

        return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)

    def sample(self, p, x, conditioning, unconditional_conditioning, steps=None):
@@ -291,19 +308,7 @@ class KDiffusionSampler:
        sigmas = self.model_wrap.get_sigmas(steps)
        x = x * sigmas[0]

        self.model_wrap_cfg.step = 0
        self.sampler_noise_index = 0

        if hasattr(k_diffusion.sampling, 'trange'):
            k_diffusion.sampling.trange = lambda *args, **kwargs: extended_trange(self, *args, **kwargs)

        if self.sampler_noises is not None:
            k_diffusion.sampling.torch = TorchHijack(self)

        extra_params_kwargs = {}
        for param_name in self.extra_params:
            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)
        extra_params_kwargs = self.initialize(p)

        samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)

+3 −2
Original line number Diff line number Diff line
@@ -221,7 +221,8 @@ options_templates.update(options_section(('ui', "User interface"), {
}))

options_templates.update(options_section(('sampler-params', "Sampler parameters"), {
  "eta": OptionInfo(0.0, "DDIM and K Ancestral eta", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
  "eta_ddim": OptionInfo(0.0, "eta (noise multiplier) for DDIM", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
  "eta_ancestral": OptionInfo(1.0, "eta (noise multiplier) for ancestral samplers", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
  "ddim_discretize": OptionInfo('uniform', "img2img DDIM discretize", gr.Radio, {"choices": ['uniform', 'quad']}),
  's_churn': OptionInfo(0.0, "sigma churn", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
  's_tmin':  OptionInfo(0.0, "sigma tmin",  gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+6 −6
Original line number Diff line number Diff line
@@ -91,7 +91,7 @@ axis_options = [
    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label),
    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label),
    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label),
    AxisOption("DDIM Eta",    float, apply_field("ddim_eta"), format_value_add_label),
    AxisOption("Eta", float, apply_field("eta"), format_value_add_label),
    AxisOptionImg2Img("Denoising", float, apply_field("denoising_strength"), format_value_add_label),  # as it is now all AxisOptionImg2Img items must go after AxisOption ones
]