Commit 7c9c19b2 authored by catboxanon's avatar catboxanon
Browse files

Refactor postprocessing to use generator to resolve OOM issues

parent ae6b3090
Loading
Loading
Loading
Loading
+31 −32
Original line number Diff line number Diff line
@@ -11,10 +11,9 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,

    shared.state.begin(job="extras")

    image_data = []
    image_names = []
    outputs = []

    def get_images(extras_mode, image, image_folder, input_dir):
        if extras_mode == 1:
            for img in image_folder:
                if isinstance(img, Image.Image):
@@ -23,8 +22,7 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
                else:
                    image = Image.open(os.path.abspath(img.name))
                    fn = os.path.splitext(img.orig_name)[0]
            image_data.append(image)
            image_names.append(fn)
                yield image, fn
        elif extras_mode == 2:
            assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
            assert input_dir, 'input directory not selected'
@@ -35,13 +33,10 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
                    image = Image.open(filename)
                except Exception:
                    continue
            image_data.append(image)
            image_names.append(filename)
                yield image, filename
        else:
            assert image, 'image not selected'

        image_data.append(image)
        image_names.append(None)
            yield image, None

    if extras_mode == 2 and output_dir != '':
        outpath = output_dir
@@ -50,14 +45,16 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,

    infotext = ''

    for image, name in zip(image_data, image_names):
    for image_data, name in get_images(extras_mode, image, image_folder, input_dir):
        image_data: Image.Image

        shared.state.textinfo = name

        parameters, existing_pnginfo = images.read_info_from_image(image)
        parameters, existing_pnginfo = images.read_info_from_image(image_data)
        if parameters:
            existing_pnginfo["parameters"] = parameters

        pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))
        pp = scripts_postprocessing.PostprocessedImage(image_data.convert("RGB"))

        scripts.scripts_postproc.run(pp, args)

@@ -78,6 +75,8 @@ def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir,
        if extras_mode != 2 or show_extras_results:
            outputs.append(pp.image)

        image_data.close()

    devices.torch_gc()

    return outputs, ui_common.plaintext_to_html(infotext), ''