Commit 976c1053 authored by CodeHatchling's avatar CodeHatchling
Browse files

Cleaned up code, moved main code contributions into soft_inpainting.py

parent 259d33c3
Loading
Loading
Loading
Loading
+7 −49
Original line number Diff line number Diff line
@@ -892,55 +892,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

                # Generate the mask(s) based on similarity between the original and denoised latent vectors
                if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
                    # latent_mask = p.nmask[0].float().cpu()

                    # convert the original mask into a form we use to scale distances for thresholding
                    # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
                    # mask_scalar = mask_scalar / (1.00001-mask_scalar)
                    # mask_scalar = mask_scalar.numpy()

                    latent_orig = p.init_latent
                    latent_proc = samples_ddim
                    latent_distance = torch.norm(latent_proc - latent_orig, p=2, dim=1)

                    kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)

                    for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, p.overlay_images)):
                        converted_mask = distance_map.float().cpu().numpy()
                        converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                       percentile_min=0.9, percentile_max=1, min_width=1)
                        converted_mask = images.weighted_histogram_filter(converted_mask,  kernel, kernel_center,
                                                       percentile_min=0.25, percentile_max=0.75, min_width=1)

                        # The distance at which opacity of original decreases to 50%
                        # half_weighted_distance = 1  # * mask_scalar
                        # converted_mask = converted_mask / half_weighted_distance

                        converted_mask = 1 / (1 + converted_mask ** 2)
                        converted_mask = images.smootherstep(converted_mask)
                        converted_mask = 1 - converted_mask
                        converted_mask = 255. * converted_mask
                        converted_mask = converted_mask.astype(np.uint8)
                        converted_mask = Image.fromarray(converted_mask)
                        converted_mask = images.resize_image(2, converted_mask, p.width, p.height)
                        converted_mask = create_binary_mask(converted_mask, round=False)

                        # Remove aliasing artifacts using a gaussian blur.
                        converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))

                        # Expand the mask to fit the whole image if needed.
                        if p.paste_to is not None:
                            converted_mask = uncrop(converted_mask,
                                                    (overlay_image.width, overlay_image.height),
                                                    p.paste_to)

                        p.masks_for_overlay[i] = converted_mask

                        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
                        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
                                           mask=ImageOps.invert(converted_mask.convert('L')))

                        p.overlay_images[i] = image_masked.convert('RGBA')
                    si.generate_adaptive_masks(latent_orig=p.init_latent,
                                               latent_processed=samples_ddim,
                                               overlay_images=p.overlay_images,
                                               masks_for_overlay=p.masks_for_overlay,
                                               width=p.width,
                                               height=p.height,
                                               paste_to=p.paste_to)

                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim,
                                                     target_device=devices.cpu,
+10 −74
Original line number Diff line number Diff line
@@ -94,76 +94,6 @@ class CFGDenoiser(torch.nn.Module):
        self.sampler.sampler_extra_args['uncond'] = uc

    def forward(self, x, sigma, uncond, cond, cond_scale, s_min_uncond, image_cond):
        def latent_blend(a, b, t, one_minus_t=None):

            """
            Interpolates two latent image representations according to the parameter t,
            where the interpolated vectors' magnitudes are also interpolated separately.
            The "detail_preservation" factor biases the magnitude interpolation towards
            the larger of the two magnitudes.
            """
            # NOTE: We use inplace operations wherever possible.

            if one_minus_t is None:
                one_minus_t = 1 - t

            if self.soft_inpainting is None:
                return a * one_minus_t + b * t

            # Linearly interpolate the image vectors.
            a_scaled = a * one_minus_t
            b_scaled = b * t
            image_interp = a_scaled
            image_interp.add_(b_scaled)
            result_type = image_interp.dtype
            del a_scaled, b_scaled

            # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
            # 64-bit operations are used here to allow large exponents.
            current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)

            # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
            a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * one_minus_t
            b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(self.soft_inpainting.inpaint_detail_preservation) * t
            desired_magnitude = a_magnitude
            desired_magnitude.add_(b_magnitude).pow_(1 / self.soft_inpainting.inpaint_detail_preservation)
            del a_magnitude, b_magnitude, one_minus_t

            # Change the linearly interpolated image vectors' magnitudes to the value we want.
            # This is the last 64-bit operation.
            image_interp_scaling_factor = desired_magnitude
            image_interp_scaling_factor.div_(current_magnitude)
            image_interp_scaled = image_interp
            image_interp_scaled.mul_(image_interp_scaling_factor)
            del current_magnitude
            del desired_magnitude
            del image_interp
            del image_interp_scaling_factor

            image_interp_scaled = image_interp_scaled.to(result_type)
            del result_type

            return image_interp_scaled

        def get_modified_nmask(nmask, _sigma):
            """
            Converts a negative mask representing the transparency of the original latent vectors being overlayed
            to a mask that is scaled according to the denoising strength for this step.

            Where:
                0 = fully opaque, infinite density, fully masked
                1 = fully transparent, zero density, fully unmasked

            We bring this transparency to a power, as this allows one to simulate N number of blending operations
            where N can be any positive real value. Using this one can control the balance of influence between
            the denoiser and the original latents according to the sigma value.

            NOTE: "mask" is not used
            """
            if self.soft_inpainting is None:
                return nmask

            return torch.pow(nmask, (_sigma ** self.soft_inpainting.mask_blend_power) * self.soft_inpainting.mask_blend_scale)

        if state.interrupted or state.skipped:
            raise sd_samplers_common.InterruptedException
@@ -184,9 +114,12 @@ class CFGDenoiser(torch.nn.Module):
        # Blend in the original latents (before)
        if self.mask_before_denoising and self.mask is not None:
            if self.soft_inpainting is None:
                x = latent_blend(self.init_latent, x, self.nmask, self.mask)
                x = self.init_latent * self.mask + self.nmask * x
            else:
                x = latent_blend(self.init_latent, x, get_modified_nmask(self.nmask, sigma))
                x = si.latent_blend(self.soft_inpainting,
                                    self.init_latent,
                                    x,
                                    si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))

        batch_size = len(conds_list)
        repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -290,9 +223,12 @@ class CFGDenoiser(torch.nn.Module):
        # Blend in the original latents (after)
        if not self.mask_before_denoising and self.mask is not None:
            if self.soft_inpainting is None:
                denoised = latent_blend(self.init_latent, denoised, self.nmask, self.mask)
                denoised = self.init_latent * self.mask + self.nmask * denoised
            else:
                denoised = latent_blend(self.init_latent, denoised, get_modified_nmask(self.nmask, sigma))
                denoised = si.latent_blend(self.soft_inpainting,
                                           self.init_latent,
                                           denoised,
                                           si.get_modified_nmask(self.soft_inpainting, self.nmask, sigma))

        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

+157 −20
Original line number Diff line number Diff line
@@ -4,13 +4,6 @@ class SoftInpaintingSettings:
        self.mask_blend_scale = mask_blend_scale
        self.inpaint_detail_preservation = inpaint_detail_preservation

    def get_paste_fields(self):
        return [
            (self.mask_blend_power, gen_param_labels.mask_blend_power),
            (self.mask_blend_scale, gen_param_labels.mask_blend_scale),
            (self.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation),
        ]

    def add_generation_params(self, dest):
        dest[enabled_gen_param_label] = True
        dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
@@ -18,25 +11,169 @@ class SoftInpaintingSettings:
        dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation


# ------------------- Methods -------------------


def latent_blend(soft_inpainting, a, b, t):
    """
    Interpolates two latent image representations according to the parameter t,
    where the interpolated vectors' magnitudes are also interpolated separately.
    The "detail_preservation" factor biases the magnitude interpolation towards
    the larger of the two magnitudes.
    """
    import torch

    # NOTE: We use inplace operations wherever possible.

    one_minus_t = 1 - t

    # Linearly interpolate the image vectors.
    a_scaled = a * one_minus_t
    b_scaled = b * t
    image_interp = a_scaled
    image_interp.add_(b_scaled)
    result_type = image_interp.dtype
    del a_scaled, b_scaled

    # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
    # 64-bit operations are used here to allow large exponents.
    current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)

    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
    a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t
    b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t
    desired_magnitude = a_magnitude
    desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
    del a_magnitude, b_magnitude, one_minus_t

    # Change the linearly interpolated image vectors' magnitudes to the value we want.
    # This is the last 64-bit operation.
    image_interp_scaling_factor = desired_magnitude
    image_interp_scaling_factor.div_(current_magnitude)
    image_interp_scaling_factor = image_interp_scaling_factor.to(result_type)
    image_interp_scaled = image_interp
    image_interp_scaled.mul_(image_interp_scaling_factor)
    del current_magnitude
    del desired_magnitude
    del image_interp
    del image_interp_scaling_factor
    del result_type

    return image_interp_scaled


def get_modified_nmask(soft_inpainting, nmask, sigma):
    """
    Converts a negative mask representing the transparency of the original latent vectors being overlayed
    to a mask that is scaled according to the denoising strength for this step.

    Where:
        0 = fully opaque, infinite density, fully masked
        1 = fully transparent, zero density, fully unmasked

    We bring this transparency to a power, as this allows one to simulate N number of blending operations
    where N can be any positive real value. Using this one can control the balance of influence between
    the denoiser and the original latents according to the sigma value.

    NOTE: "mask" is not used
    """
    import torch
    return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)


def generate_adaptive_masks(
        latent_orig,
        latent_processed,
        overlay_images,
        masks_for_overlay,
        width, height,
        paste_to):
    import torch
    import numpy as np
    import modules.processing as proc
    import modules.images as images
    from PIL import Image, ImageOps, ImageFilter

    # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
    # latent_mask = p.nmask[0].float().cpu()
    # convert the original mask into a form we use to scale distances for thresholding
    # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
    # mask_scalar = mask_scalar / (1.00001-mask_scalar)
    # mask_scalar = mask_scalar.numpy()

    latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)

    kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)

    for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
        converted_mask = distance_map.float().cpu().numpy()
        converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                          percentile_min=0.9, percentile_max=1, min_width=1)
        converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
                                                          percentile_min=0.25, percentile_max=0.75, min_width=1)

        # The distance at which opacity of original decreases to 50%
        # half_weighted_distance = 1  # * mask_scalar
        # converted_mask = converted_mask / half_weighted_distance

        converted_mask = 1 / (1 + converted_mask ** 2)
        converted_mask = images.smootherstep(converted_mask)
        converted_mask = 1 - converted_mask
        converted_mask = 255. * converted_mask
        converted_mask = converted_mask.astype(np.uint8)
        converted_mask = Image.fromarray(converted_mask)
        converted_mask = images.resize_image(2, converted_mask, width, height)
        converted_mask = proc.create_binary_mask(converted_mask, round=False)

        # Remove aliasing artifacts using a gaussian blur.
        converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))

        # Expand the mask to fit the whole image if needed.
        if paste_to is not None:
            converted_mask = proc. uncrop(converted_mask,
                                    (overlay_image.width, overlay_image.height),
                                    paste_to)

        masks_for_overlay[i] = converted_mask

        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
                           mask=ImageOps.invert(converted_mask.convert('L')))

        overlay_images[i] = image_masked.convert('RGBA')


# ------------------- Constants -------------------


default = SoftInpaintingSettings(1, 0.5, 4)

enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled"
enabled_el_id = "soft_inpainting_enabled"

default = SoftInpaintingSettings(1, 0.5, 4)
ui_labels = SoftInpaintingSettings("Schedule bias", "Preservation strength", "Transition contrast boost")
ui_labels = SoftInpaintingSettings(
    "Schedule bias",
    "Preservation strength",
    "Transition contrast boost")

ui_info = SoftInpaintingSettings(
    mask_blend_power="Shifts when preservation of original content occurs during denoising.",
                     # "Below 1: Stronger preservation near the end (with low sigma)\n"
                     # "1: Balanced (proportional to sigma)\n"
                     # "Above 1: Stronger preservation in the beginning (with high sigma)",
    mask_blend_scale="How strongly partially masked content should be preserved.",
                     # "Low values: Favors generated content.\n"
                     # "High values: Favors original content.",
    inpaint_detail_preservation="Amplifies the contrast that may be lost in partially masked regions.")

gen_param_labels = SoftInpaintingSettings("Soft inpainting schedule bias", "Soft inpainting preservation strength", "Soft inpainting transition contrast boost")
el_ids = SoftInpaintingSettings("mask_blend_power", "mask_blend_scale", "inpaint_detail_preservation")
    "Shifts when preservation of original content occurs during denoising.",
    "How strongly partially masked content should be preserved.",
    "Amplifies the contrast that may be lost in partially masked regions.")

gen_param_labels = SoftInpaintingSettings(
    "Soft inpainting schedule bias",
    "Soft inpainting preservation strength",
    "Soft inpainting transition contrast boost")

el_ids = SoftInpaintingSettings(
    "mask_blend_power",
    "mask_blend_scale",
    "inpaint_detail_preservation")


# ------------------- UI -------------------


def gradio_ui():
+0 −7
Original line number Diff line number Diff line
@@ -683,13 +683,6 @@ def create_ui():
                            with FormRow():
                                soft_inpainting = si.gradio_ui()


                            """
                                mask_blend_power = gr.Slider(label='Blending bias', minimum=0, maximum=8, step=0.1, value=1, elem_id="img2img_mask_blend_power")
                                mask_blend_scale = gr.Slider(label='Blending preservation', minimum=0, maximum=8, step=0.05, value=0.5, elem_id="img2img_mask_blend_scale")
                                inpaint_detail_preservation = gr.Slider(label='Blending contrast boost', minimum=1, maximum=32, step=0.5, value=4, elem_id="img2img_mask_blend_offset")
                            """

                            with FormRow():
                                inpainting_mask_invert = gr.Radio(label='Mask mode', choices=['Inpaint masked', 'Inpaint not masked'], value='Inpaint masked', type="index", elem_id="img2img_mask_mode")