Commit a7aa00d4 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

Merge remote-tracking branch 'mk2/outpainting-mk2-batch-out'

parents 704036ff 1fc278bc
Loading
Loading
Loading
Loading
+80 −59
Original line number Diff line number Diff line
@@ -172,23 +172,22 @@ class Script(scripts.Script):
        if down > 0:
            down = target_h - init_img.height - up

        init_image = p.init_images[0]

        state.job_count = (1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0)

        def expand(init, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
        def expand(init, count, expand_pixels, is_left=False, is_right=False, is_top=False, is_bottom=False):
            is_horiz = is_left or is_right
            is_vert = is_top or is_bottom
            pixels_horiz = expand_pixels if is_horiz else 0
            pixels_vert = expand_pixels if is_vert else 0

            res_w = init.width + pixels_horiz
            res_h = init.height + pixels_vert
            images_to_process = []
            output_images = []
            for n in range(count):
                res_w = init[n].width + pixels_horiz
                res_h = init[n].height + pixels_vert
                process_res_w = math.ceil(res_w / 64) * 64
                process_res_h = math.ceil(res_h / 64) * 64

                img = Image.new("RGB", (process_res_w, process_res_h))
            img.paste(init, (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
                img.paste(init[n], (pixels_horiz if is_left else 0, pixels_vert if is_top else 0))
                mask = Image.new("RGB", (process_res_w, process_res_h), "white")
                draw = ImageDraw.Draw(mask)
                draw.rectangle((
@@ -201,26 +200,27 @@ class Script(scripts.Script):
                np_image = (np.asarray(img) / 255.0).astype(np.float64)
                np_mask = (np.asarray(mask) / 255.0).astype(np.float64)
                noised = get_matched_noise(np_image, np_mask, noise_q, color_variation)
            out = Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB")
                output_images.append(Image.fromarray(np.clip(noised * 255., 0., 255.).astype(np.uint8), mode="RGB"))

            target_width = min(process_width, init.width + pixels_horiz) if is_horiz else img.width
            target_height = min(process_height, init.height + pixels_vert) if is_vert else img.height
                target_width = min(process_width, init[n].width + pixels_horiz) if is_horiz else img.width
                target_height = min(process_height, init[n].height + pixels_vert) if is_vert else img.height
                p.width = target_width if is_horiz else img.width
                p.height = target_height if is_vert else img.height

                crop_region = (
                0 if is_left else out.width - target_width,
                0 if is_top else out.height - target_height,
                target_width if is_left else out.width,
                target_height if is_top else out.height,
                    0 if is_left else output_images[n].width - target_width,
                    0 if is_top else output_images[n].height - target_height,
                    target_width if is_left else output_images[n].width,
                    target_height if is_top else output_images[n].height,
                )

            image_to_process = out.crop(crop_region)
                mask = mask.crop(crop_region)

            p.width = target_width if is_horiz else img.width
            p.height = target_height if is_vert else img.height
            p.init_images = [image_to_process]
                p.image_mask = mask

                image_to_process = output_images[n].crop(crop_region)
                images_to_process.append(image_to_process)

            p.init_images = images_to_process

            latent_mask = Image.new("RGB", (p.width, p.height), "white")
            draw = ImageDraw.Draw(latent_mask)
            draw.rectangle((
@@ -232,31 +232,52 @@ class Script(scripts.Script):
            p.latent_mask = latent_mask

            proc = process_images(p)
            proc_img = proc.images[0]

            if initial_seed_and_info[0] is None:
                initial_seed_and_info[0] = proc.seed
                initial_seed_and_info[1] = proc.info

            out.paste(proc_img, (0 if is_left else out.width - proc_img.width, 0 if is_top else out.height - proc_img.height))
            out = out.crop((0, 0, res_w, res_h))
            return out
            for n in range(count):
                output_images[n].paste(proc.images[n], (0 if is_left else output_images[n].width - proc.images[n].width, 0 if is_top else output_images[n].height - proc.images[n].height))
                output_images[n] = output_images[n].crop((0, 0, res_w, res_h))

            return output_images

        img = init_image
        batch_count = p.n_iter
        batch_size = p.batch_size
        p.n_iter = 1
        state.job_count = batch_count * batch_size * ((1 if left > 0 else 0) + (1 if right > 0 else 0) + (1 if up > 0 else 0) + (1 if down > 0 else 0))
        all_processed_images = []

        for i in range(batch_count):
            imgs = [init_img] * batch_size
            state.job = f"Batch {i + 1} out of {batch_count}"

            if left > 0:
            img = expand(img, left, is_left=True)
                imgs = expand(imgs, batch_size, left, is_left=True)
            if right > 0:
            img = expand(img, right, is_right=True)
                imgs = expand(imgs, batch_size, right, is_right=True)
            if up > 0:
            img = expand(img, up, is_top=True)
                imgs = expand(imgs, batch_size, up, is_top=True)
            if down > 0:
            img = expand(img, down, is_bottom=True)
                imgs = expand(imgs, batch_size, down, is_bottom=True)

            all_processed_images += imgs

        res = Processed(p, [img], initial_seed_and_info[0], initial_seed_and_info[1])
        all_images = all_processed_images

        combined_grid_image = images.image_grid(all_processed_images)
        unwanted_grid_because_of_img_count = len(all_processed_images) < 2 and opts.grid_only_if_multiple
        if opts.return_grid and not unwanted_grid_because_of_img_count:
            all_images = [combined_grid_image] + all_processed_images

        res = Processed(p, all_images, initial_seed_and_info[0], initial_seed_and_info[1])

        if opts.samples_save:
            for img in all_processed_images:
                images.save_image(img, p.outpath_samples, "", res.seed, p.prompt, opts.grid_format, info=res.info, p=p)

        return res
        if opts.grid_save and not unwanted_grid_because_of_img_count:
            images.save_image(combined_grid_image, p.outpath_grids, "grid", res.seed, p.prompt, opts.grid_format, info=res.info, short_filename=not opts.grid_extended_filename, grid=True, p=p)

        return res