Unverified Commit 2d9635cc authored by DejitaruJin's avatar DejitaruJin Committed by GitHub
Browse files

Fix display and save order for X/Y/Z Grid script

parent 0cc0ee1b
Loading
Loading
Loading
Loading
+73 −60
Original line number Diff line number Diff line
@@ -25,8 +25,6 @@ from modules.ui_components import ToolButton

fill_values_symbol = "\U0001f4d2"  # 📒

AxisInfo = namedtuple('AxisInfo', ['axis', 'values'])


def apply_field(field):
    def fun(p, x, xs):
@@ -188,7 +186,6 @@ axis_options = [
    AxisOption("Steps", int, apply_field("steps")),
    AxisOptionTxt2Img("Hires steps", int, apply_field("hr_second_pass_steps")),
    AxisOption("CFG Scale", float, apply_field("cfg_scale")),
    AxisOptionImg2Img("Image CFG Scale", float, apply_field("image_cfg_scale")),
    AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
    AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
    AxisOptionTxt2Img("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
@@ -213,49 +210,47 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    title_texts = [[images.GridAnnotation(z)] for z in z_labels]

    # Temporary list of all the images that are generated to be populated into the grid.
    # Will be filled with empty images for any individual step that fails to process properly
    image_cache = [None] * (len(xs) * len(ys) * len(zs))
    list_size = (len(xs) * len(ys) * len(zs))

    processed_result = None
    cell_mode = "P"
    cell_size = (1, 1)

    state.job_count = len(xs) * len(ys) * len(zs) * p.n_iter
    state.job_count = list_size * p.n_iter

    def process_cell(x, y, z, ix, iy, iz):
        nonlocal image_cache, processed_result, cell_mode, cell_size
        nonlocal processed_result

        def index(ix, iy, iz):
            return ix + iy * len(xs) + iz * len(xs) * len(ys)

        state.job = f"{index(ix, iy, iz) + 1} out of {len(xs) * len(ys) * len(zs)}"
        state.job = f"{index(ix, iy, iz) + 1} out of {list_size}"

        processed: Processed = cell(x, y, z)

        try:
            # this dereference will throw an exception if the image was not processed
            # (this happens in cases such as if the user stops the process from the UI)
            processed_image = processed.images[0]

        if processed_result is None:
                # Use our first valid processed result as a template container to hold our full results
            # Use our first processed result object as a template container to hold our full results
            processed_result = copy(processed)
                cell_mode = processed_image.mode
                cell_size = processed_image.size
                processed_result.images = [Image.new(cell_mode, cell_size)]
                processed_result.all_prompts = [processed.prompt]
                processed_result.all_seeds = [processed.seed]
                processed_result.infotexts = [processed.infotexts[0]]

            image_cache[index(ix, iy, iz)] = processed_image
            if include_lone_images:
                processed_result.images.append(processed_image)
                processed_result.all_prompts.append(processed.prompt)
                processed_result.all_seeds.append(processed.seed)
                processed_result.infotexts.append(processed.infotexts[0])
        except:
            image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
            processed_result.images = [None] * list_size
            processed_result.all_prompts = [None] * list_size
            processed_result.all_seeds = [None] * list_size
            processed_result.infotexts = [None] * list_size
            processed_result.index_of_first_image = 0

        idx = index(ix, iy, iz)
        if processed.images:
            # Non-empty list indicates some degree of success.
            processed_result.images[idx] = processed.images[0]
            processed_result.all_prompts[idx] = processed.prompt
            processed_result.all_seeds[idx] = processed.seed
            processed_result.infotexts[idx] = processed.infotexts[0]
        else:
            cell_mode = "P"
            cell_size = (processed_result.width, processed_result.height)
            if processed_result.images[0] is not None:
                cell_mode = processed_result.images[0].mode
                #This corrects size in case of batches:
                cell_size = processed_result.images[0].size
            processed_result.images[idx] = Image.new(cell_mode, cell_size)


    if first_axes_processed == 'x':
        for ix, x in enumerate(xs):
@@ -289,27 +284,36 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
                        process_cell(x, y, z, ix, iy, iz)

    if not processed_result:
        # Should never happen, I've only seen it on one of four open tabs and it needed to refresh.
        print("Unexpected error: Processing could not begin, you may need to refresh the tab or restart the service.")
        return Processed(p, [])
    elif not any(processed_result.images):
        print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
        return Processed(p, [])

    sub_grids = [None] * len(zs)
    for i in range(len(zs)):
        start_index = i * len(xs) * len(ys)
    z_count = len(zs)
    sub_grids = [None] * z_count
    for i in range(z_count):
        start_index = (i * len(xs) * len(ys)) + i
        end_index = start_index + len(xs) * len(ys)
        grid = images.image_grid(image_cache[start_index:end_index], rows=len(ys))
        grid = images.image_grid(processed_result.images[start_index:end_index], rows=len(ys))
        if draw_legend:
            grid = images.draw_grid_annotations(grid, cell_size[0], cell_size[1], hor_texts, ver_texts, margin_size)
        sub_grids[i] = grid
        if include_sub_grids and len(zs) > 1:
            processed_result.images.insert(i+1, grid)

    sub_grid_size = sub_grids[0].size
    z_grid = images.image_grid(sub_grids, rows=1)
            grid = images.draw_grid_annotations(grid, processed_result.images[start_index].size[0], processed_result.images[start_index].size[1], hor_texts, ver_texts, margin_size)
        processed_result.images.insert(i, grid)
        processed_result.all_prompts.insert(i, processed_result.all_prompts[start_index])
        processed_result.all_seeds.insert(i, processed_result.all_seeds[start_index])
        processed_result.infotexts.insert(i, processed_result.infotexts[start_index])

    sub_grid_size = processed_result.images[0].size
    z_grid = images.image_grid(processed_result.images[:z_count], rows=1)
    if draw_legend:
        z_grid = images.draw_grid_annotations(z_grid, sub_grid_size[0], sub_grid_size[1], title_texts, [[images.GridAnnotation()]])
    processed_result.images[0] = z_grid
    processed_result.images.insert(0, z_grid)
    processed_result.all_prompts.insert(0, processed_result.all_prompts[0])
    processed_result.all_seeds.insert(0, processed_result.all_seeds[0])
    processed_result.infotexts.insert(0, processed_result.infotexts[0])

    return processed_result, sub_grids
    return processed_result


class SharedSettingsStackHelper(object):
@@ -364,7 +368,7 @@ class Script(scripts.Script):
                include_lone_images = gr.Checkbox(label='Include Sub Images', value=False, elem_id=self.elem_id("include_lone_images"))
                include_sub_grids = gr.Checkbox(label='Include Sub Grids', value=False, elem_id=self.elem_id("include_sub_grids"))
            with gr.Column():
                margin_size = gr.Slider(label="Grid margins (px)", minimum=0, maximum=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
                margin_size = gr.Slider(label="Grid margins (px)", min=0, max=500, value=0, step=2, elem_id=self.elem_id("margin_size"))
        
        with gr.Row(variant="compact", elem_id="swap_axes"):
            swap_xy_axes_button = gr.Button(value="Swap X/Y axes", elem_id="xy_grid_swap_axes_button")
@@ -526,14 +530,10 @@ class Script(scripts.Script):

        grid_infotext = [None]

        state.xyz_plot_x = AxisInfo(x_opt, xs)
        state.xyz_plot_y = AxisInfo(y_opt, ys)
        state.xyz_plot_z = AxisInfo(z_opt, zs)

        # If one of the axes is very slow to change between (like SD model
        # checkpoint), then make sure it is in the outer iteration of the nested
        # `for` loop.
        first_axes_processed = 'x'
        first_axes_processed = 'z'
        second_axes_processed = 'y'
        if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
            first_axes_processed = 'x'
@@ -593,7 +593,7 @@ class Script(scripts.Script):
            return res

        with SharedSettingsStackHelper():
            processed, sub_grids = draw_xyz_grid(
            processed = draw_xyz_grid(
                p,
                xs=xs,
                ys=ys,
@@ -610,11 +610,24 @@ class Script(scripts.Script):
                margin_size=margin_size
            )

        if opts.grid_save and len(sub_grids) > 1:
            for sub_grid in sub_grids:
                images.save_image(sub_grid, p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)

        if opts.grid_save:
            images.save_image(processed.images[0], p.outpath_grids, "xyz_grid", info=grid_infotext[0], extension=opts.grid_format, prompt=p.prompt, seed=processed.seed, grid=True, p=p)
        z_count = len(zs)

        if not include_lone_images:
            # Don't need sub-images anymore, drop from list:
            processed.images = processed.images[:z_count+1]

        if opts.grid_save and processed.images:
            # Auto-save main and sub-grids:
            grid_count = z_count + 1 if z_count > 1 else 1
            for g in range(grid_count):
                images.save_image(processed.images[g], p.outpath_grids, "xyz_grid", info=processed.infotexts[g], extension=opts.grid_format, prompt=processed.all_prompts[g], seed=processed.all_seeds[g], grid=True, p=processed)

        if not include_sub_grids:
            # Done with sub-grids, drop all related information:
            for sg in range(z_count):
                del processed.images[1]
                del processed.all_prompts[1]
                del processed.all_seeds[1]
                del processed.infotexts[1]

        return processed