Commit ec877472 authored by EllangoK's avatar EllangoK
Browse files

swaps xyz axes internally if one costs more

parent e46bfa5a
Loading
Loading
Loading
Loading
+52 −12
Original line number Original line Diff line number Diff line
@@ -205,7 +205,7 @@ axis_options = [
]
]




def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, swap_axes_processing_order):
def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend, include_lone_images, include_sub_grids, first_axes_processed, second_axes_processed):
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
    hor_texts = [[images.GridAnnotation(x)] for x in x_labels]
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    ver_texts = [[images.GridAnnotation(y)] for y in y_labels]
    title_texts = [[images.GridAnnotation(z)] for z in z_labels]
    title_texts = [[images.GridAnnotation(z)] for z in z_labels]
@@ -251,16 +251,36 @@ def draw_xyz_grid(p, xs, ys, zs, x_labels, y_labels, z_labels, cell, draw_legend
        except:
        except:
            image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)
            image_cache[index(ix, iy, iz)] = Image.new(cell_mode, cell_size)


    if swap_axes_processing_order:
    if first_axes_processed == 'x':
        for ix, x in enumerate(xs):
        for ix, x in enumerate(xs):
            if second_axes_processed == 'y':
                for iy, y in enumerate(ys):
                for iy, y in enumerate(ys):
                for iy, y in enumerate(zs):
                    for iz, z in enumerate(zs):
                        process_cell(x, y, z, ix, iy, iz)
                        process_cell(x, y, z, ix, iy, iz)
            else:
            else:
                for iz, z in enumerate(zs):
                    for iy, y in enumerate(ys):
                    for iy, y in enumerate(ys):
                        process_cell(x, y, z, ix, iy, iz)
    elif first_axes_processed == 'y':
        for iy, y in enumerate(ys):
            if second_axes_processed == 'x':
                for ix, x in enumerate(xs):
                for ix, x in enumerate(xs):
                    for iz, z in enumerate(zs):
                    for iz, z in enumerate(zs):
                        process_cell(x, y, z, ix, iy, iz)
                        process_cell(x, y, z, ix, iy, iz)
            else:
                for iz, z in enumerate(zs):
                    for ix, x in enumerate(xs):
                        process_cell(x, y, z, ix, iy, iz)
    elif first_axes_processed == 'z':
        for iz, z in enumerate(zs):
            if second_axes_processed == 'x':
                for ix, x in enumerate(xs):
                    for iy, y in enumerate(ys):
                        process_cell(x, y, z, ix, iy, iz)
            else:
                for iy, y in enumerate(ys):
                    for ix, x in enumerate(xs):
                        process_cell(x, y, z, ix, iy, iz)


    if not processed_result:
    if not processed_result:
        print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
        print("Unexpected error: draw_xyz_grid failed to return even a single processed image")
@@ -322,7 +342,7 @@ class Script(scripts.Script):
                with gr.Row():
                with gr.Row():
                    y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
                    y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
                    y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
                    y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
                    fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyzz_grid_fill_y_tool_button", visible=False)
                    fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)


                with gr.Row():
                with gr.Row():
                    z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
                    z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
@@ -487,7 +507,26 @@ class Script(scripts.Script):
        # If one of the axes is very slow to change between (like SD model
        # If one of the axes is very slow to change between (like SD model
        # checkpoint), then make sure it is in the outer iteration of the nested
        # checkpoint), then make sure it is in the outer iteration of the nested
        # `for` loop.
        # `for` loop.
        swap_axes_processing_order = x_opt.cost > y_opt.cost
        first_axes_processed = 'x'
        second_axes_processed = 'y'
        if x_opt.cost > y_opt.cost and x_opt.cost > z_opt.cost:
            first_axes_processed = 'x'
            if y_opt.cost > z_opt.cost:
                second_axes_processed = 'y'
            else:
                second_axes_processed = 'z'
        elif y_opt.cost > x_opt.cost and y_opt.cost > z_opt.cost:
            first_axes_processed = 'y'
            if x_opt.cost > z_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'z'
        elif z_opt.cost > x_opt.cost and z_opt.cost > y_opt.cost:
            first_axes_processed = 'z'
            if x_opt.cost > y_opt.cost:
                second_axes_processed = 'x'
            else:
                second_axes_processed = 'y'


        def cell(x, y, z):
        def cell(x, y, z):
            if shared.state.interrupted:
            if shared.state.interrupted:
@@ -538,7 +577,8 @@ class Script(scripts.Script):
                draw_legend=draw_legend,
                draw_legend=draw_legend,
                include_lone_images=include_lone_images,
                include_lone_images=include_lone_images,
                include_sub_grids=include_sub_grids,
                include_sub_grids=include_sub_grids,
                swap_axes_processing_order=swap_axes_processing_order
                first_axes_processed=first_axes_processed,
                second_axes_processed=second_axes_processed
            )
            )


        if opts.grid_save:
        if opts.grid_save: