Commit 6fc12428 authored by CodeHatchling's avatar CodeHatchling
Browse files

Fixed issue where batched inpainting (batch size > 1) wouldn't work because of...

Fixed issue where batched inpainting (batch size > 1) wouldn't work because of mismatched tensor sizes. The 'already_decoded' decoded case should also be handled correctly (tested indirectly).
parent b32a334e
Loading
Loading
Loading
Loading
+15 −8
Original line number Diff line number Diff line
@@ -883,14 +883,21 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if getattr(samples_ddim, 'already_decoded', False):
                x_samples_ddim = samples_ddim
                # todo: generate adaptive masks based on pixel differences.
                # if p.masks_for_overlay is used, it will already be populated with masks
                if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
                    si.apply_masks(soft_inpainting=p.soft_inpainting,
                                   nmask=p.nmask,
                                   overlay_images=p.overlay_images,
                                   masks_for_overlay=p.masks_for_overlay,
                                   width=p.width,
                                   height=p.height,
                                   paste_to=p.paste_to)
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method

                # Generate the mask(s) based on similarity between the original and denoised latent vectors
                if getattr(p, "image_mask", None) is not None and getattr(p, "soft_inpainting", None) is not None:
                    si.generate_adaptive_masks(latent_orig=p.init_latent,
                    si.apply_adaptive_masks(latent_orig=p.init_latent,
                                            latent_processed=samples_ddim,
                                            overlay_images=p.overlay_images,
                                            masks_for_overlay=p.masks_for_overlay,
+56 −10
Original line number Diff line number Diff line
@@ -25,26 +25,32 @@ def latent_blend(soft_inpainting, a, b, t):

    # NOTE: We use inplace operations wherever possible.

    one_minus_t = 1 - t
    # [4][w][h] to [1][4][w][h]
    t2 = t.unsqueeze(0)
    # [4][w][h] to [1][1][w][h] - the [4] seem redundant.
    t3 = t[0].unsqueeze(0).unsqueeze(0)

    one_minus_t2 = 1 - t2
    one_minus_t3 = 1 - t3

    # Linearly interpolate the image vectors.
    a_scaled = a * one_minus_t
    b_scaled = b * t
    a_scaled = a * one_minus_t2
    b_scaled = b * t2
    image_interp = a_scaled
    image_interp.add_(b_scaled)
    result_type = image_interp.dtype
    del a_scaled, b_scaled
    del a_scaled, b_scaled, t2, one_minus_t2

    # Calculate the magnitude of the interpolated vectors. (We will remove this magnitude.)
    # 64-bit operations are used here to allow large exponents.
    current_magnitude = torch.norm(image_interp, p=2, dim=1).to(torch.float64).add_(0.00001)
    current_magnitude = torch.norm(image_interp, p=2, dim=1, keepdim=True).to(torch.float64).add_(0.00001)

    # Interpolate the powered magnitudes, then un-power them (bring them back to a power of 1).
    a_magnitude = torch.norm(a, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t
    b_magnitude = torch.norm(b, p=2, dim=1).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t
    a_magnitude = torch.norm(a, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * one_minus_t3
    b_magnitude = torch.norm(b, p=2, dim=1, keepdim=True).to(torch.float64).pow_(soft_inpainting.inpaint_detail_preservation) * t3
    desired_magnitude = a_magnitude
    desired_magnitude.add_(b_magnitude).pow_(1 / soft_inpainting.inpaint_detail_preservation)
    del a_magnitude, b_magnitude, one_minus_t
    del a_magnitude, b_magnitude, t3, one_minus_t3

    # Change the linearly interpolated image vectors' magnitudes to the value we want.
    # This is the last 64-bit operation.
@@ -78,10 +84,11 @@ def get_modified_nmask(soft_inpainting, nmask, sigma):
    NOTE: "mask" is not used
    """
    import torch
    return torch.pow(nmask, (sigma ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)
    # todo: Why is sigma 2D? Both values are the same.
    return torch.pow(nmask, (sigma[0] ** soft_inpainting.mask_blend_power) * soft_inpainting.mask_blend_scale)


def generate_adaptive_masks(
def apply_adaptive_masks(
        latent_orig,
        latent_processed,
        overlay_images,
@@ -142,6 +149,45 @@ def generate_adaptive_masks(

        overlay_images[i] = image_masked.convert('RGBA')

def apply_masks(
        soft_inpainting,
        nmask,
        overlay_images,
        masks_for_overlay,
        width, height,
        paste_to):
    import torch
    import numpy as np
    import modules.processing as proc
    import modules.images as images
    from PIL import Image, ImageOps, ImageFilter

    converted_mask = nmask[0].float()
    converted_mask = torch.clamp(converted_mask, min=0, max=1).pow_(soft_inpainting.mask_blend_scale / 2)
    converted_mask = 255. * converted_mask
    converted_mask = converted_mask.cpu().numpy().astype(np.uint8)
    converted_mask = Image.fromarray(converted_mask)
    converted_mask = images.resize_image(2, converted_mask, width, height)
    converted_mask = proc.create_binary_mask(converted_mask, round=False)

    # Remove aliasing artifacts using a gaussian blur.
    converted_mask = converted_mask.filter(ImageFilter.GaussianBlur(radius=4))

    # Expand the mask to fit the whole image if needed.
    if paste_to is not None:
        converted_mask = proc.uncrop(converted_mask,
                                     (width, height),
                                     paste_to)

    for i, overlay_image in enumerate(overlay_images):
        masks_for_overlay[i] = converted_mask

        image_masked = Image.new('RGBa', (overlay_image.width, overlay_image.height))
        image_masked.paste(overlay_image.convert("RGBA").convert("RGBa"),
                           mask=ImageOps.invert(converted_mask.convert('L')))

        overlay_images[i] = image_masked.convert('RGBA')


# ------------------- Constants -------------------