Commit 6e2ce4e7 authored by random_thoughtss's avatar random_thoughtss
Browse files

Added image conditioning to latent upscale.

Only computed if the mask weight is not 1.0 to avoid extra memory.
Also includes some code cleanup.
parent 44ab954f
Loading
Loading
Loading
Loading
+11 −18
Original line number Diff line number Diff line
@@ -134,11 +134,7 @@ class StableDiffusionProcessing():
            # Dummy zero conditioning if we're not using inpainting model.
            # Still takes up a bit of memory, but no encoder call.
            # Pretty sure we can just make this a 1x1 image since its not going to be used besides its batch size.
            return torch.zeros(
                x.shape[0], 5, 1, 1, 
                dtype=x.dtype, 
                device=x.device
            )
            return x.new_zeros(x.shape[0], 5, 1, 1)

        height = height or self.height
        width = width or self.width
@@ -156,11 +152,7 @@ class StableDiffusionProcessing():
    def img2img_image_conditioning(self, source_image, latent_image, image_mask = None):
        if self.sampler.conditioning_key not in {'hybrid', 'concat'}:
            # Dummy zero conditioning if we're not using inpainting model.
            return torch.zeros(
                latent_image.shape[0], 5, 1, 1,
                dtype=latent_image.dtype,
                device=latent_image.device
            )
            return latent_image.new_zeros(latent_image.shape[0], 5, 1, 1)

        # Handle the different mask inputs
        if image_mask is not None:
@@ -174,11 +166,10 @@ class StableDiffusionProcessing():
                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                conditioning_mask = torch.round(conditioning_mask)
        else:
            conditioning_mask = torch.ones(1, 1, *source_image.shape[-2:])
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

        # Create another latent image, this time with a masked version of the original input.
        # Smoothly interpolate between the masked and unmasked latent conditioning image using a parameter.
        conditioning_mask = conditioning_mask.to(source_image.device)
        conditioning_image = torch.lerp(
            source_image,
            source_image * (1.0 - conditioning_mask),
@@ -653,6 +644,12 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

        if opts.use_scale_latent_for_hires_fix:
            samples = torch.nn.functional.interpolate(samples, size=(self.height // opt_f, self.width // opt_f), mode="bilinear")
            
            # Avoid making the inpainting conditioning unless necessary as 
            # this does need some extra compute to decode / encode the image again.
            if getattr(self, "inpainting_mask_weight", shared.opts.inpainting_mask_weight) < 1.0:
                image_conditioning = self.img2img_image_conditioning(decode_first_stage(self.sd_model, samples), samples)
            else:
                image_conditioning = self.txt2img_image_conditioning(samples)

        else:
@@ -675,11 +672,7 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

            samples = self.sd_model.get_first_stage_encoding(self.sd_model.encode_first_stage(decoded_samples))

            image_conditioning = self.img2img_image_conditioning(
                decoded_samples, 
                samples, 
                decoded_samples.new_ones(decoded_samples.shape[0], 1, decoded_samples.shape[2], decoded_samples.shape[3])
            )
            image_conditioning = self.img2img_image_conditioning(decoded_samples, samples)

        shared.state.nextjob()