Commit 6c7b6ecb authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

alternative refiner implementation

parent 57e8a11d
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -344,6 +344,8 @@ infotext_to_setting_name_mapping = [
    ('Pad conds', 'pad_cond_uncond'),
    ('VAE Encoder', 'sd_vae_encode_method'),
    ('VAE Decoder', 'sd_vae_decode_method'),
    ('Refiner', 'sd_refiner_checkpoint'),
    ('Refiner switch at', 'sd_refiner_switch_at'),
]


+77 −10
Original line number Diff line number Diff line
@@ -178,6 +178,8 @@ class StableDiffusionProcessing:
        self.extra_network_data = None
        self.seeds = None
        self.subseeds = None
        self.recorded_checkpoint = None
        self.recorded_checkpoint_hash = None

        self.step_multiplier = 1
        self.cached_uc = StableDiffusionProcessing.cached_uc
@@ -186,6 +188,7 @@ class StableDiffusionProcessing:
        self.c = None

        self.user = None
        self.image_conditioning = None

    @property
    def sd_model(self):
@@ -377,6 +380,54 @@ class StableDiffusionProcessing:
        """Returns whether generated images need to be written to disk"""
        return opts.samples_save and not self.do_not_save_samples and (opts.save_incomplete_images or not state.interrupted and not state.skipped)

    def run_refiner(self, samples):
        """Switch to the refiner checkpoint and continue sampling `samples` with it.

        The checkpoint named by `shared.opts.sd_refiner_checkpoint` is loaded in place of the
        current model, conditionings are rebuilt via `self.setup_conds()`, and the sampler's
        img2img path is run for the steps that remain after `self.sampler.stop_at`.

        Side effects visible in this method:
          - advances `shared.state` to the next job;
          - replaces `self.sampler` and `self.image_conditioning`;
          - stores the name/hash of the model that was loaded before the switch in
            `self.recorded_checkpoint` / `self.recorded_checkpoint_hash`;
          - adds 'Refiner' and 'Refiner switch at' to `self.extra_generation_params`;
          - leaves the refiner checkpoint loaded when it returns.

        `self.denoising_strength` is changed only for the duration of the img2img pass and is
        restored afterwards, including when sampling raises.

        Args:
            samples: latent batch produced by the first sampling pass.

        Returns:
            The latent batch returned by the sampler's `sample_img2img`.

        Raises:
            Exception: if no checkpoint matches `shared.opts.sd_refiner_checkpoint`.
        """
        shared.state.nextjob()

        # remember where the first pass stopped before dropping the old sampler
        stopped_at = self.sampler.stop_at
        self.sampler = None

        # resolve the refiner before doing any decoding, so a bad checkpoint name fails early
        refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
        if refiner_checkpoint_info is None:
            raise Exception(f'Could not find checkpoint with name {shared.opts.sd_refiner_checkpoint}')

        a_is_sdxl = shared.sd_model.is_sdxl

        # decode with the model that produced the latents; this has to happen before the weights are swapped
        decoded_samples = decode_latent_batch(shared.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        # record the pre-switch model identity and the refiner settings
        self.recorded_checkpoint = shared.sd_model.sd_checkpoint_info.name_for_extra
        self.recorded_checkpoint_hash = shared.sd_model.sd_model_hash
        self.extra_generation_params['Refiner'] = refiner_checkpoint_info.short_title
        self.extra_generation_params['Refiner switch at'] = shared.opts.sd_refiner_switch_at

        # the refiner is a temporary swap: load it inside SkipWritingToConfig
        with sd_models.SkipWritingToConfig():
            sd_models.reload_model_weights(info=refiner_checkpoint_info)

        devices.torch_gc()
        self.setup_conds()

        b_is_sdxl = shared.sd_model.is_sdxl

        if a_is_sdxl != b_is_sdxl:
            # the two models differ in is_sdxl: re-encode the decoded images with the newly loaded model
            # (presumably because their latents are not interchangeable - TODO confirm)
            decoded_samples = torch.stack(decoded_samples).float()
            # map from [-1, 1] to [0, 1], clamping anything outside that range
            decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
            latent = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
        else:
            latent = samples

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)

        with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
            denoising_strength = self.denoising_strength

            try:
                # fraction of the schedule that was not run by the first pass
                self.denoising_strength = 1.0 - stopped_at / self.steps
                self.image_conditioning = txt2img_image_conditioning(shared.sd_model, latent, self.width, self.height)
                self.sampler = sd_samplers.create_sampler(self.sampler_name, shared.sd_model)
                samples = self.sampler.sample_img2img(self, latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
            finally:
                # always put the caller's value back, even if sampling fails
                self.denoising_strength = denoising_strength

        return samples


class Processed:
    def __init__(self, p: StableDiffusionProcessing, images_list, seed=-1, info="", subseed=None, all_prompts=None, all_negative_prompts=None, all_seeds=None, all_subseeds=None, index_of_first_image=0, infotexts=None, comments=""):
@@ -553,6 +604,9 @@ class DecodedSamples(list):


def decode_latent_batch(model, batch, target_device=None, check_for_nans=False):
    if getattr(batch, 'already_decoded', False):
        return batch

    samples = DecodedSamples()

    for i in range(batch.shape[0]):
@@ -632,8 +686,8 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments=None, iter
        "Seed": p.all_seeds[0] if use_main_prompt else all_seeds[index],
        "Face restoration": (opts.face_restoration_model if p.restore_faces else None),
        "Size": f"{p.width}x{p.height}",
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info else shared.sd_model.sd_checkpoint_info.name_for_extra),
        "Model hash": getattr(p, 'sd_model_hash', None if not opts.add_model_hash_to_info or not shared.sd_model.sd_model_hash else p.recorded_checkpoint_hash or shared.sd_model.sd_model_hash),
        "Model": (None if not opts.add_model_name_to_info else p.recorded_checkpoint or shared.sd_model.sd_checkpoint_info.name_for_extra),
        "Variation seed": (None if p.subseed_strength == 0 else (p.all_subseeds[0] if use_main_prompt else all_subseeds[index])),
        "Variation seed strength": (None if p.subseed_strength == 0 else p.subseed_strength),
        "Seed resize from": (None if p.seed_resize_from_w <= 0 or p.seed_resize_from_h <= 0 else f"{p.seed_resize_from_w}x{p.seed_resize_from_h}"),
@@ -666,6 +720,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
    stored_opts = {k: opts.data[k] for k in p.override_settings.keys()}

    try:
        # after running refiner, the refiner model is not unloaded - webui swaps back to main model here
        if shared.sd_model.sd_checkpoint_info.title != opts.sd_model_checkpoint:
            sd_models.reload_model_weights()

        # if no checkpoint override or the override checkpoint can't be found, remove override entry and load opts checkpoint
        if sd_models.checkpoint_aliases.get(p.override_settings.get('sd_model_checkpoint')) is None:
            p.override_settings.pop('sd_model_checkpoint', None)
@@ -737,6 +795,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
    infotexts = []
    output_images = []

    have_refiner = shared.opts.sd_refiner_switch_at < 1.0 and shared.sd_model.sd_checkpoint_info.title != shared.opts.sd_refiner_checkpoint

    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(p.all_prompts, p.all_seeds, p.all_subseeds)
@@ -750,6 +810,10 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
        if state.job_count == -1:
            state.job_count = p.n_iter

        if have_refiner:
            state.job_count *= 2
            shared.total_tqdm.updateTotal(p.steps * state.job_count // 2)

        for n in range(p.n_iter):
            p.iteration = n

@@ -798,15 +862,18 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            if have_refiner:
                p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
            if opts.sd_vae_decode_method != 'Full':
                p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

            if have_refiner:
                samples_ddim = p.run_refiner(samples_ddim)

            x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -989,6 +1056,8 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        self.hr_uc = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        if self.enable_hr:
            if self.hr_checkpoint_name:
                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
@@ -1065,8 +1134,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
                self.extra_generation_params["Hires upscaler"] = self.hr_upscaler

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=seeds, subseeds=subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x
+17 −1
Original line number Diff line number Diff line
@@ -289,10 +289,26 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    return res


class SkipWritingToConfig:
    """Context manager that raises the class-wide `skip` flag for the duration of a `with` block.

    While the flag is set, load_model_weights does not store the loaded checkpoint's title in
    the config. The value the flag had on entry is put back on exit, so blocks can be nested.
    """

    # class-wide switch, read directly as SkipWritingToConfig.skip
    skip = False
    # class-level default; each instance stores the flag value it saw on entry here
    previous = None

    def __enter__(self):
        # remember the current flag, then turn skipping on
        self.previous, SkipWritingToConfig.skip = SkipWritingToConfig.skip, True
        return self

    def __exit__(self, exc_type, exc_value, exc_traceback):
        # restore whatever was in effect before this block was entered;
        # nothing is returned, so an exception raised in the block propagates
        restored = self.previous
        SkipWritingToConfig.skip = restored


def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer):
    sd_model_hash = checkpoint_info.calculate_shorthash()
    timer.record("calculate hash")

    if not SkipWritingToConfig.skip:
        shared.opts.data["sd_model_checkpoint"] = checkpoint_info.title

    if state_dict is None:
+1 −1
Original line number Diff line number Diff line
@@ -44,7 +44,7 @@ class VanillaStableDiffusionSampler:
        return 0

    def launch_sampling(self, steps, func):
        state.sampling_steps = steps
        state.sampling_steps = self.stop_at if self.stop_at is not None else steps
        state.sampling_step = 0

        try:
+1 −1
Original line number Diff line number Diff line
@@ -305,7 +305,7 @@ class KDiffusionSampler:
        shared.total_tqdm.update()

    def launch_sampling(self, steps, func):
        state.sampling_steps = steps
        state.sampling_steps = self.stop_at if self.stop_at is not None else steps
        state.sampling_step = 0

        try:
Loading