Unverified Commit d073637e authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #6803 from space-nuko/xy-grid-performance-improvement

Optimize XY grid to run slower axes fewer times
parents 064983c0 029260b4
Loading
Loading
Loading
Loading
+70 −53
Original line number Original line Diff line number Diff line
@@ -178,44 +178,44 @@ def str_permutations(x):
    """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
    """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
    return x
    return x


AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm"])
AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm"])
AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"])




axis_options = [
axis_options = [
    AxisOption("Nothing", str, do_nothing, format_nothing, None),
    AxisOption("Nothing", str, do_nothing, format_nothing, None, 0),
    AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None),
    AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0),
    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None),
    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0),
    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None),
    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0),
    AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None),
    AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0),
    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None),
    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0),
    AxisOption("Prompt S/R", str, apply_prompt, format_value, None),
    AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0),
    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None),
    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0),
    AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers),
    AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0),
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints),
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0),
    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks),
    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2),
    AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None),
    AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0),
    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None),
    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0),
    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None),
    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0),
    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None),
    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0),
    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None),
    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0),
    AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None),
    AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0),
    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None),
    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0),
    AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None),
    AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0),
    AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None),
    AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0),
    AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None),
    AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0),
    AxisOption("VAE", str, apply_vae, format_value_add_label, None),
    AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7),
    AxisOption("Styles", str, apply_styles, format_value_add_label, None),
    AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0),
]
]




def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images):
def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_images, swap_axes_processing_order):
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]


    # Temporary list of all the images that are generated to be populated into the grid.
    # Temporary list of all the images that are generated to be populated into the grid.
    # Will be filled with empty images for any individual step that fails to process properly
    # Will be filled with empty images for any individual step that fails to process properly
    image_cache = []
    image_cache = [None] * (len(xs) * len(ys))


    processed_result = None
    processed_result = None
    cell_mode = "P"
    cell_mode = "P"
@@ -223,11 +223,13 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_


    state.job_count = len(xs) * len(ys) * p.n_iter
    state.job_count = len(xs) * len(ys) * p.n_iter


    for iy, y in enumerate(ys):
    def process_cell(x, y, ix, iy):
        for ix, x in enumerate(xs):
        nonlocal image_cache, processed_result, cell_mode, cell_size

        state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"
        state.job = f"{ix + iy * len(xs) + 1} out of {len(xs) * len(ys)}"


        processed: Processed = cell(x, y)
        processed: Processed = cell(x, y)

        try:
        try:
            # this dereference will throw an exception if the image was not processed
            # this dereference will throw an exception if the image was not processed
            # (this happens in cases such as if the user stops the process from the UI)
            # (this happens in cases such as if the user stops the process from the UI)
@@ -240,14 +242,23 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_
                cell_size = processed_image.size
                cell_size = processed_image.size
                processed_result.images = [Image.new(cell_mode, cell_size)]
                processed_result.images = [Image.new(cell_mode, cell_size)]


                image_cache.append(processed_image)
            image_cache[ix + iy * len(xs)] = processed_image
            if include_lone_images:
            if include_lone_images:
                processed_result.images.append(processed_image)
                processed_result.images.append(processed_image)
                processed_result.all_prompts.append(processed.prompt)
                processed_result.all_prompts.append(processed.prompt)
                processed_result.all_seeds.append(processed.seed)
                processed_result.all_seeds.append(processed.seed)
                processed_result.infotexts.append(processed.infotexts[0])
                processed_result.infotexts.append(processed.infotexts[0])
        except:
        except:
                image_cache.append(Image.new(cell_mode, cell_size))
            image_cache[ix + iy * len(xs)] = Image.new(cell_mode, cell_size)

    if swap_axes_processing_order:
        for ix, x in enumerate(xs):
            for iy, y in enumerate(ys):
                process_cell(x, y, ix, iy)
    else:
        for iy, y in enumerate(ys):
            for ix, x in enumerate(xs):
                process_cell(x, y, ix, iy)


    if not processed_result:
    if not processed_result:
        print("Unexpected error: draw_xy_grid failed to return even a single processed image")
        print("Unexpected error: draw_xy_grid failed to return even a single processed image")
@@ -417,6 +428,11 @@ class Script(scripts.Script):


        grid_infotext = [None]
        grid_infotext = [None]


        # If one of the axes is very slow to change between (like SD model
        # checkpoint), then make sure it is in the outer iteration of the nested
        # `for` loop.
        swap_axes_processing_order = x_opt.cost > y_opt.cost

        def cell(x, y):
        def cell(x, y):
            if shared.state.interrupted:
            if shared.state.interrupted:
                return Processed(p, [], p.seed, "")
                return Processed(p, [], p.seed, "")
@@ -455,7 +471,8 @@ class Script(scripts.Script):
                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
                y_labels=[y_opt.format_value(p, y_opt, y) for y in ys],
                cell=cell,
                cell=cell,
                draw_legend=draw_legend,
                draw_legend=draw_legend,
                include_lone_images=include_lone_images
                include_lone_images=include_lone_images,
                swap_axes_processing_order=swap_axes_processing_order
            )
            )


        if opts.grid_save:
        if opts.grid_save: