Commit f284ae23 authored by CodeHatchling's avatar CodeHatchling
Browse files

Added parameters for the composite stage, fixed batched generation.

parent 0ef4a4cb
Loading
Loading
Loading
Loading
+155 −43
Original line number Diff line number Diff line
@@ -6,22 +6,34 @@ import modules.scripts as scripts


class SoftInpaintingSettings:
    """
    Plain record with one slot per soft-inpainting setting.

    The same six attribute names are reused for several parallel tables in this
    module (numeric defaults, UI labels, UI tooltips, infotext labels and
    element ids), so the constructor performs no validation or conversion and
    simply stores whatever it is given.

    Slots, in constructor order:
        mask_blend_power / mask_blend_scale: exponent terms applied to the
            negative mask while blending latents.
        inpaint_detail_preservation: power used when interpolating latent
            vector magnitudes.
        composite_mask_influence: how much the mask biases the difference
            threshold in the pixel composite stage.
        composite_difference_threshold: scale of the latent difference at
            which original pixels are half blended.
        composite_difference_contrast: exponent shaping the blend falloff.
    """

    def __init__(self,
                 mask_blend_power,
                 mask_blend_scale,
                 inpaint_detail_preservation,
                 composite_mask_influence,
                 composite_difference_threshold,
                 composite_difference_contrast):
        # Latent blending stage.
        self.mask_blend_power = mask_blend_power
        self.mask_blend_scale = mask_blend_scale
        self.inpaint_detail_preservation = inpaint_detail_preservation
        # Pixel composite stage.
        self.composite_mask_influence = composite_mask_influence
        self.composite_difference_threshold = composite_difference_threshold
        self.composite_difference_contrast = composite_difference_contrast

    def add_generation_params(self, dest):
        """
        Record these settings in ``dest`` under their infotext labels.

        ``dest`` must support item assignment. Besides one entry per setting
        (keyed by the matching ``gen_param_labels`` attribute), an "enabled"
        entry is written with the value ``True``.
        """
        dest[enabled_gen_param_label] = True
        dest[gen_param_labels.mask_blend_power] = self.mask_blend_power
        dest[gen_param_labels.mask_blend_scale] = self.mask_blend_scale
        dest[gen_param_labels.inpaint_detail_preservation] = self.inpaint_detail_preservation
        dest[gen_param_labels.composite_mask_influence] = self.composite_mask_influence
        dest[gen_param_labels.composite_difference_threshold] = self.composite_difference_threshold
        dest[gen_param_labels.composite_difference_contrast] = self.composite_difference_contrast


# ------------------- Methods -------------------


def latent_blend(soft_inpainting, a, b, t):
def latent_blend(settings, a, b, t):
    """
    Interpolates two latent image representations according to the parameter t,
    where the interpolated vectors' magnitudes are also interpolated separately.
@@ -54,11 +66,11 @@ def latent_blend(soft_inpainting, a, b, t):

    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
        soft_inpainting.inpaint_detail_preservation) * one_minus_t3
        settings.inpaint_detail_preservation) * one_minus_t3
    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(
        soft_inpainting.inpaint_detail_preservation) * t3
        settings.inpaint_detail_preservation) * t3
    desired_magnitude = a_magnitude
    desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
    desired_magnitude.add_(b_magnitude).pow_(1 / settings.inpaint_detail_preservation)
    del a_magnitude, b_magnitude, t3, one_minus_t3

    # Change the linearly interpolated image vectors' magnitudes to the value we want.
@@ -77,7 +89,7 @@ def latent_blend(soft_inpainting, a, b, t):
    return image_interp_scaled


def get_modified_nmask(soft_inpainting, nmask, sigma):
def get_modified_nmask(settings, nmask, sigma):
    """
    Converts a negative mask representing the transparency of the original latent vectors being overlayed
    to a mask that is scaled according to the denoising strength for this step.
@@ -93,10 +105,12 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
    NOTE: "mask" is not used
    """
    import torch
    return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
    return torch.pow(nmask, (sigma ** settings.mask_blend_power) * settings.mask_blend_scale)


def apply_adaptive_masks(
        settings:SoftInpaintingSettings,
        nmask,
        latent_orig,
        latent_processed,
        overlay_images,
@@ -108,11 +122,13 @@ def apply_adaptive_masks(
    from PIL import Image, ImageOps, ImageFilter

    # TODO: Bias the blending according to the latent mask, add adjustable parameter for bias control.
    # latent_mask = p.nmask[0].float().cpu()
    latent_mask = nmask[0].float()
    # convert the original mask into a form we use to scale distances for thresholding
    # mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (p.mask_blend_scale / 2))
    # mask_scalar = mask_scalar / (1.00001-mask_scalar)
    # mask_scalar = mask_scalar.numpy()
    mask_scalar = 1-(torch.clamp(latent_mask, min=0, max=1) ** (settings.mask_blend_scale / 2))
    mask_scalar = (0.5 * (1-settings.composite_mask_influence)
                   + mask_scalar * settings.composite_mask_influence)
    mask_scalar = mask_scalar / (1.00001-mask_scalar)
    mask_scalar = mask_scalar.cpu().numpy()

    latent_distance = torch.norm(latent_processed - latent_orig, p=2, dim=1)

@@ -128,10 +144,10 @@ def apply_adaptive_masks(
                                                          percentile_min=0.25, percentile_max=0.75, min_width=1)

        # The distance at which opacity of original decreases to 50%
        # half_weighted_distance = 1  # * mask_scalar
        # converted_mask = converted_mask / half_weighted_distance
        half_weighted_distance = settings.composite_difference_threshold * mask_scalar
        converted_mask = converted_mask / half_weighted_distance

        converted_mask = 1 / (1 + converted_mask ** 2)
        converted_mask = 1 / (1 + converted_mask ** settings.composite_difference_contrast)
        converted_mask = smootherstep(converted_mask)
        converted_mask = 1 - converted_mask
        converted_mask = 255. * converted_mask
@@ -161,7 +177,7 @@ def apply_adaptive_masks(


def apply_masks(
        soft_inpainting,
        settings,
        nmask,
        overlay_images,
        width, height,
@@ -172,7 +188,7 @@ def apply_masks(
    from PIL import Image, ImageOps, ImageFilter

    converted_mask = nmask[0].float()
    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(settings.mask_blend_scale / 2)
    converted_mask = 255. * converted_mask
    converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
    converted_mask = Image.fromarray(converted_mask)
@@ -395,7 +411,7 @@ def get_gaussian_kernel(stddev_radius=1.0, max_radius=2):
# ------------------- Constants -------------------


# Initial value of every setting, in SoftInpaintingSettings constructor order.
# All six fields are required; the constructor has no defaults of its own.
default = SoftInpaintingSettings(1, 0.5, 4, 0, 0.5, 2)

enabled_ui_label = "Soft inpainting"
enabled_gen_param_label = "Soft inpainting enabled"
@@ -404,25 +420,37 @@ enabled_el_id = "soft_inpainting_enabled"
# The tables below reuse SoftInpaintingSettings as a six-slot record, so one
# attribute name looks up a setting's label, tooltip, infotext key or element id.

# Text shown next to each control.
ui_labels = SoftInpaintingSettings(
    "Schedule bias",
    "Preservation strength",
    "Transition contrast boost",
    "Mask influence",
    "Difference threshold",
    "Difference contrast")

# Short description attached to each control.
ui_info = SoftInpaintingSettings(
    "Shifts when preservation of original content occurs during denoising.",
    "How strongly partially masked content should be preserved.",
    "Amplifies the contrast that may be lost in partially masked regions.",
    "How strongly the original mask should bias the difference threshold.",
    "How much an image region can change before the original pixels are not blended in anymore.",
    "How sharp the transition should be between blended and not blended.")

# Keys under which each setting is stored in the generation parameters.
gen_param_labels = SoftInpaintingSettings(
    "Soft inpainting schedule bias",
    "Soft inpainting preservation strength",
    "Soft inpainting transition contrast boost",
    "Soft inpainting mask influence",
    "Soft inpainting difference threshold",
    "Soft inpainting difference contrast")

# Element ids passed to the UI components.
el_ids = SoftInpaintingSettings(
    "mask_blend_power",
    "mask_blend_scale",
    "inpaint_detail_preservation",
    "composite_mask_influence",
    "composite_difference_threshold",
    "composite_difference_contrast")


# ------------------- Script -------------------


class Script(scripts.Script):
@@ -449,28 +477,62 @@ class Script(scripts.Script):
                    **High _Mask blur_** values are recommended!
                    """)

                result = SoftInpaintingSettings(
                power = \
                    gr.Slider(label=ui_labels.mask_blend_power,
                              info=ui_info.mask_blend_power,
                              minimum=0,
                              maximum=8,
                              step=0.1,
                              value=default.mask_blend_power,
                              elem_id=el_ids.mask_blend_power),
                              elem_id=el_ids.mask_blend_power)
                scale = \
                    gr.Slider(label=ui_labels.mask_blend_scale,
                              info=ui_info.mask_blend_scale,
                              minimum=0,
                              maximum=8,
                              step=0.05,
                              value=default.mask_blend_scale,
                              elem_id=el_ids.mask_blend_scale),
                              elem_id=el_ids.mask_blend_scale)
                detail = \
                    gr.Slider(label=ui_labels.inpaint_detail_preservation,
                              info=ui_info.inpaint_detail_preservation,
                              minimum=1,
                              maximum=32,
                              step=0.5,
                              value=default.inpaint_detail_preservation,
                              elem_id=el_ids.inpaint_detail_preservation))
                              elem_id=el_ids.inpaint_detail_preservation)

                gr.Markdown(
                    """
                    ### Pixel Composite Settings
                    """)

                mask_inf = \
                    gr.Slider(label=ui_labels.composite_mask_influence,
                              info=ui_info.composite_mask_influence,
                              minimum=0,
                              maximum=1,
                              step=0.05,
                              value=default.composite_mask_influence,
                              elem_id=el_ids.composite_mask_influence)

                dif_thresh = \
                    gr.Slider(label=ui_labels.composite_difference_threshold,
                              info=ui_info.composite_difference_threshold,
                              minimum=0,
                              maximum=8,
                              step=0.25,
                              value=default.composite_difference_threshold,
                              elem_id=el_ids.composite_difference_threshold)

                dif_contr = \
                    gr.Slider(label=ui_labels.composite_difference_contrast,
                              info=ui_info.composite_difference_contrast,
                              minimum=0,
                              maximum=8,
                              step=0.25,
                              value=default.composite_difference_contrast,
                              elem_id=el_ids.composite_difference_contrast)

                with gr.Accordion("Help", open=False):
                    gr.Markdown(
@@ -507,41 +569,86 @@ class Script(scripts.Script):
                        - **High values**: Stronger contrast, may over-saturate colors.
                        """)

                    gr.Markdown(
                        """
                        ## Pixel Composite Settings
                        
                        Masks are generated based on how much a part of the image changed after denoising.
                        These masks are used to blend the original and final images together.
                        If the difference is low, the original pixels are used instead of the pixels returned by the inpainting process.
                        """)

                    gr.Markdown(
                        f"""
                        ### {ui_labels.composite_mask_influence}

                        This parameter controls how much the mask should bias this sensitivity to difference.

                        - **0**: Ignore the mask, only consider differences in image content.
                        - **1**: Follow the mask closely despite image content changes.
                        """)

                    gr.Markdown(
                        f"""
                        ### {ui_labels.composite_difference_threshold}

                        This value represents the difference at which the original pixels will have less than 50% opacity.

                        - **Low values**: Two image patches must be almost the same in order to retain the original pixels.
                        - **High values**: Two image patches can be very different and still retain the original pixels.
                        """)

                    gr.Markdown(
                        f"""
                        ### {ui_labels.composite_difference_contrast}

                        This value controls how sharp the transition is between blended and not blended regions.

                        - **Low values**: The transition between original and inpainted pixels is more gradual.
                        - **High values**: The transition between original and inpainted pixels is more abrupt.
                        """)

        self.infotext_fields = [(soft_inpainting_enabled, enabled_gen_param_label),
                                (result.mask_blend_power, gen_param_labels.mask_blend_power),
                                (result.mask_blend_scale, gen_param_labels.mask_blend_scale),
                                (result.inpaint_detail_preservation, gen_param_labels.inpaint_detail_preservation)]
                                (power, gen_param_labels.mask_blend_power),
                                (scale, gen_param_labels.mask_blend_scale),
                                (detail, gen_param_labels.inpaint_detail_preservation),
                                (mask_inf, gen_param_labels.composite_mask_influence),
                                (dif_thresh, gen_param_labels.composite_difference_threshold),
                                (dif_contr, gen_param_labels.composite_difference_contrast)]

        self.paste_field_names = []
        for _, field_name in self.infotext_fields:
            self.paste_field_names.append(field_name)

        return [soft_inpainting_enabled,
                result.mask_blend_power,
                result.mask_blend_scale,
                result.inpaint_detail_preservation]

    def process(self, p, enabled, power, scale, detail_preservation):
                power,
                scale,
                detail,
                mask_inf,
                dif_thresh,
                dif_contr]

    def process(self, p, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
        """
        Prepare the processing object ``p`` for soft inpainting.

        Does nothing when ``enabled`` is falsy. Otherwise mask rounding is
        switched off on ``p`` and the six setting values are written to
        ``p.extra_generation_params`` under their infotext labels.
        """
        if not enabled:
            return

        # Shut off the rounding it normally does.
        p.mask_round = False

        # All six values are required by the settings constructor.
        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # p.extra_generation_params["Mask rounding"] = False
        settings.add_generation_params(p.extra_generation_params)

    def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation):
    def on_mask_blend(self, p, mba: scripts.MaskBlendArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
        if not enabled:
            return

        if mba.sigma is None:
        if mba.is_final_blend:
            mba.blended_latent = mba.current_latent
            return

        settings = SoftInpaintingSettings(power, scale, detail_preservation)
        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        # todo: Why is sigma 2D? Both values are the same.
        mba.blended_latent = latent_blend(settings,
@@ -549,11 +656,11 @@ class Script(scripts.Script):
                                          mba.current_latent,
                                          get_modified_nmask(settings, mba.nmask, mba.sigma[0]))

    def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation):
    def post_sample(self, p, ps: scripts.PostSampleArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
        if not enabled:
            return

        settings = SoftInpaintingSettings(power, scale, detail_preservation)
        settings = SoftInpaintingSettings(power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr)

        from modules import images
        from modules.shared import opts
@@ -570,15 +677,20 @@ class Script(scripts.Script):

            self.overlay_images.append(image.convert('RGBA'))

        if len(p.init_images) == 1:
            self.overlay_images = self.overlay_images * p.batch_size

        if getattr(ps.samples, 'already_decoded', False):
            self.masks_for_overlay = apply_masks(soft_inpainting=settings,
            self.masks_for_overlay = apply_masks(settings=settings,
                                                 nmask=p.nmask,
                                                 overlay_images=self.overlay_images,
                                                 width=p.width,
                                                 height=p.height,
                                                 paste_to=p.paste_to)
        else:
            self.masks_for_overlay = apply_adaptive_masks(latent_orig=p.init_latent,
            self.masks_for_overlay = apply_adaptive_masks(settings=settings,
                                                          nmask=p.nmask,
                                                          latent_orig=p.init_latent,
                                                          latent_processed=ps.samples,
                                                          overlay_images=self.overlay_images,
                                                          width=p.width,
@@ -586,7 +698,7 @@ class Script(scripts.Script):
                                                          paste_to=p.paste_to)


    def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation):
    def postprocess_maskoverlay(self, p, ppmo: scripts.PostProcessMaskOverlayArgs, enabled, power, scale, detail_preservation, mask_inf, dif_thresh, dif_contr):
        if not enabled:
            return