Commit 2abc4178 authored by CodeHatchling's avatar CodeHatchling
Browse files

Re-implemented soft inpainting via a script. Also fixed some mistakes with the...

Re-implemented soft inpainting via a script. Also fixed some mistakes with the previous hooks, removed unnecessary formatting changes, and removed code that I had previously forgotten to remove.
parent ac457891
Loading
Loading
Loading
Loading
+10 −13
Original line number Diff line number Diff line
@@ -879,14 +879,13 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if p.scripts is not None:
                ps = scripts.PostSampleArgs(samples_ddim)
                p.scripts.post_sample(p, ps)
                samples_ddim = pp.samples
                samples_ddim = ps.samples

            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -944,7 +943,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                if p.scripts is not None:
                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                    p.scripts.postprocess_maskoverlay(p, ppmo)
                    mask_for_overlay, overlay_image = pp.mask_for_overlay, pp.overlay_image
                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
@@ -959,7 +958,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                original_denoised_image = image.copy()

                if p.paste_to is not None:
                    original_denoised_image = uncrop(original_denoised_image, (p.overlay_image.width, p.overlay_image.height), p.paste_to)
                    original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)

                image = apply_overlay(image, p.paste_to, overlay_image)

@@ -1512,9 +1511,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            if self.overlay_images is not None:
                self.overlay_images = self.overlay_images * self.batch_size

            if self.masks_for_overlay is not None:
                self.masks_for_overlay = self.masks_for_overlay * self.batch_size

            if self.color_corrections is not None and len(self.color_corrections) == 1:
                self.color_corrections = self.color_corrections * self.batch_size

@@ -1565,10 +1561,11 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            blended_samples = samples * self.nmask + self.init_latent * self.mask

            if self.scripts is not None:
            mba = scripts.MaskBlendArgs(self, samples, self.nmask, self.init_latent, self.mask, blended_samples, sigma=None, is_final_blend=True)
                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
                self.scripts.on_mask_blend(self, mba)
                blended_samples = mba.blended_latent

+2 −2
Original line number Diff line number Diff line
@@ -12,12 +12,12 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,
AlwaysVisible = object()

class MaskBlendArgs:
    """Arguments handed to script mask-blend hooks (`on_mask_blend`).

    Carries the latents involved in compositing a masked img2img result:
    the sampler's in-progress latent, the forward/inverted masks, the
    original init latent, and the blended output latent that a script may
    overwrite with its own blend.
    """

    def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
        self.current_latent = current_latent
        self.nmask = nmask
        self.init_latent = init_latent
        self.mask = mask
        # Scripts replace this value to substitute their own blend result.
        self.blended_latent = blended_latent

        self.denoiser = denoiser
        # No denoiser means this is the final blend after sampling finishes.
        self.is_final_blend = denoiser is None
        # Fix: sigma was accepted but not stored in the visible block, yet
        # consumers read `mba.sigma` (e.g. to detect the final blend).
        self.sigma = sigma
+401 −0
Original line number Diff line number Diff line
import gradio as gr
from modules.ui_components import InputAccordion
import modules.scripts as scripts


class SoftInpaintingSettings:
    def __init__(self, mask_blend_power, mask_blend_scale, inpaint_detail_preservation):
        self.mask_blend_power = mask_blend_power
@@ -46,8 +51,10 @@ def latent_blend(soft_inpainting, a, b, t):
    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)

    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3
    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3
    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
        soft_inpainting.inpaint_detail_preservation) * one_minus_t3
    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
        soft_inpainting.inpaint_detail_preservation) * t3
    desired_magnitude = a_magnitude
    desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
    del a_magnitude, b_magnitude, t3, one_minus_t3
@@ -84,15 +91,13 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
    NOTE: "mask" is not used
    """
    import torch
    # todo: Why is sigma 2D? Both values are the same.
    return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
    return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)


def apply_adaptive_masks(
        latent_orig,
        latent_processed,
        overlay_images,
        masks_for_overlay,
        width, height,
        paste_to):
    import torch
@@ -112,6 +117,8 @@ def apply_adaptive_masks(

    kernel, kernel_center = images.get_gaussian_kernel(stddev_radius=1.5, max_radius=2)

    masks_for_overlay = []

    for i, (distance_map, overlay_image) in enumerate(zip(latent_distance, overlay_images)):
        converted_mask = distance_map.float().cpu().numpy()
        converted_mask = images.weighted_histogram_filter(converted_mask, kernel, kernel_center,
@@ -141,7 +148,7 @@ def apply_adaptive_masks(
                                         (overlay_image.width, overlay_image.height),
                                         paste_to)

        masks_for_overlay[i] = converted_mask
        masks_for_overlay.append(converted_mask)

        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
@@ -149,11 +156,13 @@ def apply_adaptive_masks(

        overlay_images[i] = image_masked.convert('RGBA')

    return masks_for_overlay


def apply_masks(
        soft_inpainting,
        nmask,
        overlay_images,
        masks_for_overlay,
        width, height,
        paste_to):
    import torch
@@ -179,6 +188,8 @@ def apply_masks(
                                     (width, height),
                                     paste_to)

    masks_for_overlay = []

    for i, overlay_image in enumerate(overlay_images):
        masks_for_overlay[i] = converted_mask

@@ -188,6 +199,8 @@ def apply_masks(

        overlay_images[i] = image_masked.convert('RGBA')

    return masks_for_overlay


# ------------------- Constants -------------------

@@ -219,12 +232,21 @@ el_ids = SoftInpaintingSettings(
    "inpaint_detail_preservation")


# ------------------- UI -------------------
class Script(scripts.Script):

    def __init__(self):
        """Per-batch overlay state; populated in post_sample, consumed in postprocess_maskoverlay."""
        self.overlay_images = None
        self.masks_for_overlay = None

def gradio_ui():
    import gradio as gr
    from modules.ui_components import InputAccordion
    def title(self):
        # Display name used in the UI and in infotext.
        return "Soft Inpainting"

    def show(self, is_img2img):
        # Only relevant for img2img; AlwaysVisible keeps the accordion present
        # without the user having to pick the script from a dropdown.
        if is_img2img:
            return scripts.AlwaysVisible
        return False

    def ui(self, is_img2img):
        if not is_img2img:
            return

        with InputAccordion(False, label=enabled_ui_label, elem_id=enabled_el_id) as soft_inpainting_enabled:
            with gr.Group():
@@ -292,17 +314,88 @@ def gradio_ui():
                        - **High values**: Stronger contrast, may over-saturate colors.
                        """)

    return (
        [
            soft_inpainting_enabled,
            result.mask_blend_power,
            result.mask_blend_scale,
            result.inpaint_detail_preservation
        ],
        [
            (soft_inpainting_enabled, enabled_gen_param_label),
        self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
                                (result.mask_blend_power, gen_param_labels.mask_blend_power),
                                (result.mask_blend_scale, gen_param_labels.mask_blend_scale),
            (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)
        ]
    )
                                (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)]

        self.paste_field_names = []
        for _, field_name in self.infotext_fields:
            self.paste_field_names.append(field_name)

        return [soft_inpainting_enabled,
                result.mask_blend_power,
                result.mask_blend_scale,
                result.inpaint_detail_preservation]

    def process(self, p, enabled, power, scale, detail_preservation):
        """Disable hard mask rounding and record soft-inpainting parameters in the infotext."""
        if not enabled:
            return

        # Soft inpainting relies on the continuous mask values, so switch off
        # the rounding the pipeline normally applies.
        p.mask_round = False

        # Record the active settings so they appear in generation parameters.
        settings = SoftInpaintingSettings(power, scale, detail_preservation)
        settings.add_generation_params(p.extra_generation_params)

    def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation):
        """Replace the default latent blend with the soft (power/scale-shaped) blend."""
        if not enabled:
            return

        # No sigma means this is the final composite; keep the sampler output as-is.
        if mba.sigma is None:
            mba.blended_latent = mba.current_latent
            return

        settings = SoftInpaintingSettings(power, scale, detail_preservation)

        # todo: Why is sigma 2D? Both values are the same — use the first.
        shaped_nmask = get_modified_nmask(settings, mba.nmask, mba.sigma[0])
        mba.blended_latent = latent_blend(settings, mba.init_latent, mba.current_latent, shaped_nmask)

    def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation):
        """Rebuild the overlay images and compute per-image overlay masks after sampling."""
        if not enabled:
            return

        from modules import images
        from modules.shared import opts

        settings = SoftInpaintingSettings(power, scale, detail_preservation)

        # The stock pipeline punches holes into the existing overlay images,
        # so rebuild them here from the untouched init images.
        rebuilt = []
        for source_image in p.init_images:
            flattened = images.flatten(source_image, opts.img2img_background_color)

            if p.paste_to is None and p.resize_mode != 3:
                flattened = images.resize_image(p.resize_mode, flattened, p.width, p.height)

            rebuilt.append(flattened.convert('RGBA'))
        self.overlay_images = rebuilt

        if getattr(ps.samples, 'already_decoded', False):
            # Samples are already images: derive masks directly from the nmask.
            self.masks_for_overlay = apply_masks(
                soft_inpainting=settings,
                nmask=p.nmask,
                overlay_images=self.overlay_images,
                width=p.width,
                height=p.height,
                paste_to=p.paste_to)
        else:
            # Samples are still latents: derive masks adaptively from the
            # distance between the original and processed latents.
            self.masks_for_overlay = apply_adaptive_masks(
                latent_orig=p.init_latent,
                latent_processed=ps.samples,
                overlay_images=self.overlay_images,
                width=p.width,
                height=p.height,
                paste_to=p.paste_to)


    def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation):
        """Hand the mask/overlay pair rebuilt in post_sample back to the pipeline."""
        if not enabled:
            return

        index = ppmo.index
        ppmo.mask_for_overlay = self.masks_for_overlay[index]
        ppmo.overlay_image = self.overlay_images[index]
 No newline at end of file