Unverified Commit 8c32594d authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #14208 from CodeHatchling/soft-inpainting

Soft Inpainting
parents f3cc5f83 f1ff932c
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -791,3 +791,4 @@ def flatten(img, bgcolor):
        img = background

    return img.convert('RGB')
+67 −25
Original line number Diff line number Diff line
@@ -62,28 +62,35 @@ def apply_color_correction(correction, original_image):
    return image.convert('RGB')


def apply_overlay(image, paste_loc, index, overlays):
    if overlays is None or index >= len(overlays):
        return image

    overlay = overlays[index]

    if paste_loc is not None:
def uncrop(image, dest_size, paste_loc):
    x, y, w, h = paste_loc
        base_image = Image.new('RGBA', (overlay.width, overlay.height))
    base_image = Image.new('RGBA', dest_size)
    image = images.resize_image(1, image, w, h)
    base_image.paste(image, (x, y))
    image = base_image

    return image


def apply_overlay(image, paste_loc, overlay):
    if overlay is None:
        return image

    if paste_loc is not None:
        image = uncrop(image, (overlay.width, overlay.height), paste_loc)

    image = image.convert('RGBA')
    image.alpha_composite(overlay)
    image = image.convert('RGB')

    return image

def create_binary_mask(image):
def create_binary_mask(image, round=True):
    if image.mode == 'RGBA' and image.getextrema()[-1] != (255, 255):
        if round:
            image = image.split()[-1].convert("L").point(lambda x: 255 if x > 128 else 0)
        else:
            image = image.split()[-1].convert("L")
    else:
        image = image.convert('L')
    return image
@@ -308,7 +315,7 @@ class StableDiffusionProcessing:
            c_adm = torch.cat((c_adm, noise_level_emb), 1)
        return c_adm

    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None):
    def inpainting_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        self.is_using_inpainting_conditioning = True

        # Handle the different mask inputs
@@ -320,8 +327,10 @@ class StableDiffusionProcessing:
                conditioning_mask = conditioning_mask.astype(np.float32) / 255.0
                conditioning_mask = torch.from_numpy(conditioning_mask[None, None])

                # Inpainting model uses a discretized mask as input, so we round to either 1.0 or 0.0
                if round_image_mask:
                    # Caller is requesting a discretized mask as input, so we round to either 1.0 or 0.0
                    conditioning_mask = torch.round(conditioning_mask)

        else:
            conditioning_mask = source_image.new_ones(1, 1, *source_image.shape[-2:])

@@ -345,7 +354,7 @@ class StableDiffusionProcessing:

        return image_conditioning

    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None):
    def img2img_image_conditioning(self, source_image, latent_image, image_mask=None, round_image_mask=True):
        source_image = devices.cond_cast_float(source_image)

        # HACK: Using introspection as the Depth2Image model doesn't appear to uniquely
@@ -357,7 +366,7 @@ class StableDiffusionProcessing:
            return self.edit_image_conditioning(source_image)

        if self.sampler.conditioning_key in {'hybrid', 'concat'}:
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask)
            return self.inpainting_image_conditioning(source_image, latent_image, image_mask=image_mask, round_image_mask=round_image_mask)

        if self.sampler.conditioning_key == "crossattn-adm":
            return self.unclip_image_conditioning(source_image)
@@ -867,6 +876,11 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                samples_ddim = p.sample(conditioning=p.c, unconditional_conditioning=p.uc, seeds=p.seeds, subseeds=p.subseeds, subseed_strength=p.subseed_strength, prompts=p.prompts)

            if p.scripts is not None:
                ps = scripts.PostSampleArgs(samples_ddim)
                p.scripts.post_sample(p, ps)
                samples_ddim = ps.samples

            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
            else:
@@ -922,13 +936,31 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                    pp = scripts.PostprocessImageArgs(image)
                    p.scripts.postprocess_image(p, pp)
                    image = pp.image

                mask_for_overlay = getattr(p, "mask_for_overlay", None)
                overlay_image = p.overlay_images[i] if getattr(p, "overlay_images", None) is not None and i < len(p.overlay_images) else None

                if p.scripts is not None:
                    ppmo = scripts.PostProcessMaskOverlayArgs(i, mask_for_overlay, overlay_image)
                    p.scripts.postprocess_maskoverlay(p, ppmo)
                    mask_for_overlay, overlay_image = ppmo.mask_for_overlay, ppmo.overlay_image

                if p.color_corrections is not None and i < len(p.color_corrections):
                    if save_samples and opts.save_images_before_color_correction:
                        image_without_cc = apply_overlay(image, p.paste_to, i, p.overlay_images)
                        image_without_cc = apply_overlay(image, p.paste_to, overlay_image)
                        images.save_image(image_without_cc, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-before-color-correction")
                    image = apply_color_correction(p.color_corrections[i], image)

                image = apply_overlay(image, p.paste_to, i, p.overlay_images)
                # If the intention is to show the output from the model
                # that is being composited over the original image,
                # we need to keep the original image around
                # and use it in the composite step.
                original_denoised_image = image.copy()

                if p.paste_to is not None:
                    original_denoised_image = uncrop(original_denoised_image, (overlay_image.width, overlay_image.height), p.paste_to)

                image = apply_overlay(image, p.paste_to, overlay_image)

                if save_samples:
                    images.save_image(image, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p)
@@ -938,16 +970,17 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                if opts.enable_pnginfo:
                    image.info["parameters"] = text
                output_images.append(image)
                if hasattr(p, 'mask_for_overlay') and p.mask_for_overlay:

                if mask_for_overlay is not None:
                    if opts.return_mask or opts.save_mask:
                        image_mask = p.mask_for_overlay.convert('RGB')
                        image_mask = mask_for_overlay.convert('RGB')
                        if save_samples and opts.save_mask:
                            images.save_image(image_mask, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask")
                        if opts.return_mask:
                            output_images.append(image_mask)

                    if opts.return_mask_composite or opts.save_mask_composite:
                        image_mask_composite = Image.composite(image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, p.mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                        image_mask_composite = Image.composite(original_denoised_image.convert('RGBA').convert('RGBa'), Image.new('RGBa', image.size), images.resize_image(2, mask_for_overlay, image.width, image.height).convert('L')).convert('RGBA')
                        if save_samples and opts.save_mask_composite:
                            images.save_image(image_mask_composite, p.outpath_samples, "", p.seeds[i], p.prompts[i], opts.samples_format, info=infotext(i), p=p, suffix="-mask-composite")
                        if opts.return_mask_composite:
@@ -1351,6 +1384,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
    mask_blur_x: int = 4
    mask_blur_y: int = 4
    mask_blur: int = None
    mask_round: bool = True
    inpainting_fill: int = 0
    inpaint_full_res: bool = True
    inpaint_full_res_padding: int = 0
@@ -1396,7 +1430,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        if image_mask is not None:
            # image_mask is passed in as RGBA by Gradio to support alpha masks,
            # but we still want to support binary masks.
            image_mask = create_binary_mask(image_mask)
            image_mask = create_binary_mask(image_mask, round=self.mask_round)

            if self.inpainting_mask_invert:
                image_mask = ImageOps.invert(image_mask)
@@ -1503,6 +1537,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            latmask = init_mask.convert('RGB').resize((self.init_latent.shape[3], self.init_latent.shape[2]))
            latmask = np.moveaxis(np.array(latmask, dtype=np.float32), 2, 0) / 255
            latmask = latmask[0]
            if self.mask_round:
                latmask = np.around(latmask)
            latmask = np.tile(latmask[None], (4, 1, 1))

@@ -1515,7 +1550,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
            elif self.inpainting_fill == 3:
                self.init_latent = self.init_latent * self.mask

        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask)
        self.image_conditioning = self.img2img_image_conditioning(image * 2 - 1, self.init_latent, image_mask, self.mask_round)

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        x = self.rng.next()
@@ -1527,7 +1562,14 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask
            blended_samples = samples * self.nmask + self.init_latent * self.mask

            if self.scripts is not None:
                mba = scripts.MaskBlendArgs(samples, self.nmask, self.init_latent, self.mask, blended_samples)
                self.scripts.on_mask_blend(self, mba)
                blended_samples = mba.blended_latent

            samples = blended_samples

        del x
        devices.torch_gc()
+70 −0
Original line number Diff line number Diff line
@@ -11,11 +11,31 @@ from modules import shared, paths, script_callbacks, extensions, script_loading,

AlwaysVisible = object()

class MaskBlendArgs:
    def __init__(self, current_latent, nmask, init_latent, mask, blended_latent, denoiser=None, sigma=None):
        self.current_latent = current_latent
        self.nmask = nmask
        self.init_latent = init_latent
        self.mask = mask
        self.blended_latent = blended_latent

        self.denoiser = denoiser
        self.is_final_blend = denoiser is None
        self.sigma = sigma

class PostSampleArgs:
    def __init__(self, samples):
        self.samples = samples

class PostprocessImageArgs:
    def __init__(self, image):
        self.image = image

class PostProcessMaskOverlayArgs:
    def __init__(self, index, mask_for_overlay, overlay_image):
        self.index = index
        self.mask_for_overlay = mask_for_overlay
        self.overlay_image = overlay_image

class PostprocessBatchListArgs:
    def __init__(self, images):
@@ -206,6 +226,25 @@ class Script:

        pass

    def on_mask_blend(self, p, mba: MaskBlendArgs, *args):
        """
        Called in inpainting mode when the original content is blended with the inpainted content.
        This is called at every step in the denoising process and once at the end.
        If is_final_blend is true, this is called for the final blending stage.
        Otherwise, denoiser and sigma are defined and may be used to inform the procedure.
        """

        pass

    def post_sample(self, p, ps: PostSampleArgs, *args):
        """
        Called after the samples have been generated,
        but before they have been decoded by the VAE, if applicable.
        Check getattr(samples, 'already_decoded', False) to test if the images are decoded.
        """

        pass

    def postprocess_image(self, p, pp: PostprocessImageArgs, *args):
        """
        Called for every image after it has been generated.
@@ -213,6 +252,13 @@ class Script:

        pass

    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs, *args):
        """
        Called for every image after it has been generated.
        """

        pass

    def postprocess(self, p, processed, *args):
        """
        This function is called after processing ends for AlwaysVisible scripts.
@@ -767,6 +813,22 @@ class ScriptRunner:
            except Exception:
                errors.report(f"Error running postprocess_batch_list: {script.filename}", exc_info=True)

    def post_sample(self, p, ps: PostSampleArgs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.post_sample(p, ps, *script_args)
            except Exception:
                errors.report(f"Error running post_sample: {script.filename}", exc_info=True)

    def on_mask_blend(self, p, mba: MaskBlendArgs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.on_mask_blend(p, mba, *script_args)
            except Exception:
                errors.report(f"Error running post_sample: {script.filename}", exc_info=True)

    def postprocess_image(self, p, pp: PostprocessImageArgs):
        for script in self.alwayson_scripts:
            try:
@@ -775,6 +837,14 @@ class ScriptRunner:
            except Exception:
                errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

    def postprocess_maskoverlay(self, p, ppmo: PostProcessMaskOverlayArgs):
        for script in self.alwayson_scripts:
            try:
                script_args = p.script_args[script.args_from:script.args_to]
                script.postprocess_maskoverlay(p, ppmo, *script_args)
            except Exception:
                errors.report(f"Error running postprocess_image: {script.filename}", exc_info=True)

    def before_component(self, component, **kwargs):
        for callback, script in self.on_before_component_elem_id.get(kwargs.get("elem_id"), []):
            try:
+19 −2
Original line number Diff line number Diff line
@@ -56,6 +56,9 @@ class CFGDenoiser(torch.nn.Module):
        self.sampler = sampler
        self.model_wrap = None
        self.p = None

        # NOTE: masking before denoising can cause the original latents to be oversmoothed
        # as the original latents do not have noise
        self.mask_before_denoising = False

    @property
@@ -105,8 +108,21 @@ class CFGDenoiser(torch.nn.Module):

        assert not is_edit_model or all(len(conds) == 1 for conds in conds_list), "AND is not supported for InstructPix2Pix checkpoint (unless using Image CFG scale = 1.0)"

        # If we use masks, blending between the denoised and original latent images occurs here.
        def apply_blend(current_latent):
            blended_latent = current_latent * self.nmask + self.init_latent * self.mask

            if self.p.scripts is not None:
                from modules import scripts
                mba = scripts.MaskBlendArgs(current_latent, self.nmask, self.init_latent, self.mask, blended_latent, denoiser=self, sigma=sigma)
                self.p.scripts.on_mask_blend(self.p, mba)
                blended_latent = mba.blended_latent

            return blended_latent

        # Blend in the original latents (before)
        if self.mask_before_denoising and self.mask is not None:
            x = self.init_latent * self.mask + self.nmask * x
            x = apply_blend(x)

        batch_size = len(conds_list)
        repeats = [len(conds_list[i]) for i in range(batch_size)]
@@ -207,8 +223,9 @@ class CFGDenoiser(torch.nn.Module):
        else:
            denoised = self.combine_denoised(x_out, conds_list, uncond, cond_scale)

        # Blend in the original latents (after)
        if not self.mask_before_denoising and self.mask is not None:
            denoised = self.init_latent * self.mask + self.nmask * denoised
            denoised = apply_blend(denoised)

        self.sampler.last_latent = self.get_pred_x0(torch.cat([x_in[i:i + 1] for i in denoised_image_indexes]), torch.cat([x_out[i:i + 1] for i in denoised_image_indexes]), sigma)

+747 −0

File added.

Preview size limit exceeded, changes collapsed.