Commit 5a0db84b authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

add infotext

add proper support for recalculating conds in k-diffusion samplers
remove support for compvis samplers
parent 956e69bf
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [
    ('Pad conds', 'pad_cond_uncond'),
    ('VAE Encoder', 'sd_vae_encode_method'),
    ('VAE Decoder', 'sd_vae_decode_method'),
    ('Refiner', 'sd_refiner_checkpoint'),
    ('Refiner switch at', 'sd_refiner_switch_at'),
]


+10 −0
Original line number Diff line number Diff line
@@ -370,6 +370,9 @@ class StableDiffusionProcessing:
        self.uc = self.get_conds_with_caching(prompt_parser.get_learned_conditioning, negative_prompts, self.steps * self.step_multiplier, [self.cached_uc], self.extra_network_data)
        self.c = self.get_conds_with_caching(prompt_parser.get_multicond_learned_conditioning, prompts, self.steps * self.step_multiplier, [self.cached_c], self.extra_network_data)

    def get_conds(self):
        return self.c, self.uc

    def parse_extra_network_prompts(self):
        self.prompts, self.extra_network_data = extra_networks.parse_prompts(self.prompts)

@@ -1251,6 +1254,13 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                with devices.autocast():
                    extra_networks.activate(self, self.extra_network_data)

    def get_conds(self):
        if self.is_hr_pass:
            return self.hr_c, self.hr_uc

        return super().get_conds()


    def parse_extra_network_prompts(self):
        res = super().parse_extra_network_prompts()

+20 −9
Original line number Diff line number Diff line
@@ -131,16 +131,27 @@ replace_torchsde_browinan()

def apply_refiner(sampler):
    completed_ratio = sampler.step / sampler.steps
    if completed_ratio > shared.opts.sd_refiner_switch_at and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint:

    if completed_ratio <= shared.opts.sd_refiner_switch_at:
        return False

    if shared.sd_model.sd_checkpoint_info.title == shared.opts.sd_refiner_checkpoint:
        return False

    refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
    if refiner_checkpoint_info is None:
        raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')

    sampler.p.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
    sampler.p.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at

    with sd_models.SkipWritingToConfig():
        sd_models.reload_model_weights(info=refiner_checkpoint_info)

    devices.torch_gc()

    sampler.p.setup_conds()
    sampler.update_inner_model()

        sampler.p.setup_conds()
    return True

+0 −2
Original line number Diff line number Diff line
@@ -71,8 +71,6 @@ class VanillaStableDiffusionSampler:
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        sd_samplers_common.apply_refiner(self)

        if self.stop_at is not None and self.step > self.stop_at:
            raise sd_samplers_common.InterruptedException

+16 −8
Original line number Diff line number Diff line
@@ -87,8 +87,9 @@ class CFGDenoiser(torch.nn.Module):
    negative prompt.
    """

    def __init__(self):
    def __init__(self, sampler):
        super().__init__()
        self.sampler = sampler
        self.model_wrap = None
        self.mask = None
        self.nmask = None
@@ -126,11 +127,17 @@ class CFGDenoiser(torch.nn.Module):
    def update_inner_model(self):
        self.model_wrap = None

        c, uc = self.p.get_conds()
        self.sampler.sampler_extra_args['cond'] = c
        self.sampler.sampler_extra_args['uncond'] = uc

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException

        sd_samplers_common.apply_refiner(self)
        if sd_samplers_common.apply_refiner(self):
            cond = self.sampler.sampler_extra_args['cond']
            uncond = self.sampler.sampler_extra_args['uncond']

        # at self.image_cfg_scale == 1.0 produced results for edit model are the same as with normal sampling,
        # so is_edit_model is set to False to support AND composition.
@@ -282,12 +289,12 @@ class TorchHijack:

class KDiffusionSampler:
    def __init__(self, funcname, sd_model):

        self.p = None
        self.funcname = funcname
        self.func = funcname if callable(funcname) else getattr(k_diffusion.sampling, self.funcname)
        self.extra_params = sampler_extra_params.get(funcname, [])
        self.model_wrap_cfg = CFGDenoiser()
        self.sampler_extra_args = {}
        self.model_wrap_cfg = CFGDenoiser(self)
        self.model_wrap = self.model_wrap_cfg.inner_model
        self.sampler_noises = None
        self.stop_at = None
@@ -476,7 +483,7 @@ class KDiffusionSampler:

        self.model_wrap_cfg.init_latent = x
        self.last_latent = x
        extra_args = {
        self.sampler_extra_args = {
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
@@ -484,7 +491,7 @@ class KDiffusionSampler:
            's_min_uncond': self.s_min_uncond
        }

        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))
        samples = self.launch_sampling(t_enc + 1, lambda: self.func(self.model_wrap_cfg, xi, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True
@@ -514,13 +521,14 @@ class KDiffusionSampler:
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.last_latent = x
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args={
        self.sampler_extra_args = {
            'cond': conditioning,
            'image_cond': image_conditioning,
            'uncond': unconditional_conditioning,
            'cond_scale': p.cfg_scale,
            's_min_uncond': self.s_min_uncond
        }, disable=False, callback=self.callback_state, **extra_params_kwargs))
        }
        samples = self.launch_sampling(steps, lambda: self.func(self.model_wrap_cfg, x, extra_args=self.sampler_extra_args, disable=False, callback=self.callback_state, **extra_params_kwargs))

        if self.model_wrap_cfg.padded_cond_uncond:
            p.extra_generation_params["Pad conds"] = True