Commit 68659838 authored by AUTOMATIC1111
Browse files

send noisy latent into refiner without adding noise

parent 3f828206
Loading
Loading
Loading
Loading
+17 −16
Original line number Diff line number Diff line
@@ -384,11 +384,11 @@ class StableDiffusionProcessing:
        shared.state.nextjob()

        stopped_at = self.sampler.stop_at
        noisy_output = self.sampler.noisy_output
        self.sampler = None

        a_is_sdxl = shared.sd_model.is_sdxl

        decoded_samples = decode_latent_batch(shared.sd_model, samples, target_device=devices.cpu, check_for_nans=True)
        decoded_noisy = decode_latent_batch(shared.sd_model, noisy_output, target_device=devices.cpu, check_for_nans=True)

        refiner_checkpoint_info = sd_models.get_closet_checkpoint_match(shared.opts.sd_refiner_checkpoint)
        if refiner_checkpoint_info is None:
@@ -408,21 +408,21 @@ class StableDiffusionProcessing:
        b_is_sdxl = shared.sd_model.is_sdxl

        if a_is_sdxl != b_is_sdxl:
            decoded_samples = torch.stack(decoded_samples).float()
            decoded_samples = torch.clamp((decoded_samples + 1.0) / 2.0, min=0.0, max=1.0)
            latent = images_tensor_to_samples(decoded_samples, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
            decoded_noisy = torch.stack(decoded_noisy).float()
            decoded_noisy = torch.clamp((decoded_noisy + 1.0) / 2.0, min=0.0, max=1.0)
            noisy_latent = images_tensor_to_samples(decoded_noisy, approximation_indexes.get(opts.sd_vae_encode_method), shared.sd_model)
        else:
            latent = samples
            noisy_latent = noisy_output

        x = create_random_tensors([opt_C, self.height // opt_f, self.width // opt_f], seeds=self.seeds, subseeds=self.subseeds, subseed_strength=self.subseed_strength, seed_resize_from_h=self.seed_resize_from_h, seed_resize_from_w=self.seed_resize_from_w, p=self)
        x = torch.zeros_like(noisy_latent)

        with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
            denoising_strength = self.denoising_strength

            self.denoising_strength = 1.0 - stopped_at / self.steps
            self.image_conditioning = txt2img_image_conditioning(shared.sd_model, latent, self.width, self.height)
            self.denoising_strength = 1.0 - (stopped_at + 1) / self.steps
            self.image_conditioning = txt2img_image_conditioning(shared.sd_model, noisy_latent, self.width, self.height)
            self.sampler = sd_samplers.create_sampler(self.sampler_name, shared.sd_model)
            samples = self.sampler.sample_img2img(self, latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))
            samples = self.sampler.sample_img2img(self, noisy_latent, x, self.c, self.uc, image_conditioning=self.image_conditioning, steps=max(1, self.steps - stopped_at - 1))

            self.denoising_strength = denoising_strength

@@ -823,6 +823,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if state.interrupted:
                break

            sd_models.reload_model_weights()  # model can be changed for example by refiner

            p.prompts = p.all_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.negative_prompts = p.all_negative_prompts[n * p.batch_size:(n + 1) * p.batch_size]
            p.seeds = p.all_seeds[n * p.batch_size:(n + 1) * p.batch_size]
@@ -862,10 +864,12 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                p.sampler = sd_samplers.create_sampler(p.sampler_name, p.sd_model)

                if have_refiner:
                    p.sampler.stop_at = max(1, int(shared.opts.sd_refiner_switch_at * p.steps - 1))

            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if opts.sd_vae_decode_method != 'Full':
@@ -1056,8 +1060,6 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        self.hr_uc = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        if self.enable_hr:
            if self.hr_checkpoint_name:
                self.hr_checkpoint_info = sd_models.get_closet_checkpoint_match(self.hr_checkpoint_name)
@@ -1355,7 +1357,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        self.image_conditioning = None

    def init(self, all_prompts, all_seeds, all_subseeds):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)
        crop_region = None

        image_mask = self.image_mask
+2 −0
Original line number Diff line number Diff line
@@ -276,6 +276,7 @@ class KDiffusionSampler:
        self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
        self.sampler_noises = None
        self.stop_at = None
        self.noisy_output = None
        self.eta = None
        self.config = None  # set by the function calling the constructor
        self.last_latent = None
@@ -297,6 +298,7 @@ class KDiffusionSampler:
        if opts.live_preview_content == "Combined":
            sd_samplers_common.store_latent(latent)
        self.last_latent = latent
        self.noisy_output = d['x']

        if self.stop_at is not None and step > self.stop_at:
            raise sd_samplers_common.InterruptedException