Commit c01dc1cb authored by pangbo13's avatar pangbo13
Browse files

add dropdown for X/Y/Z plot

parent 22bcc7be
Loading
Loading
Loading
Loading
+23 −15
Original line number Diff line number Diff line
@@ -374,16 +374,19 @@ class Script(scripts.Script):
                with gr.Row():
                    x_type = gr.Dropdown(label="X type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
                    x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
                    x_values_dropdown = gr.Dropdown(label="X values",visible=False,multiselect=True,interactive=True)
                    fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_x_tool_button", visible=False)

                with gr.Row():
                    y_type = gr.Dropdown(label="Y type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
                    y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
                    y_values_dropdown = gr.Dropdown(label="Y values",visible=False,multiselect=True,interactive=True)
                    fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_y_tool_button", visible=False)

                with gr.Row():
                    z_type = gr.Dropdown(label="Z type", choices=[x.label for x in self.current_axis_options], value=self.current_axis_options[0].label, type="index", elem_id=self.elem_id("z_type"))
                    z_values = gr.Textbox(label="Z values", lines=1, elem_id=self.elem_id("z_values"))
                    z_values_dropdown = gr.Dropdown(label="Z values",visible=False,multiselect=True,interactive=True)
                    fill_z_button = ToolButton(value=fill_values_symbol, elem_id="xyz_grid_fill_z_tool_button", visible=False)

        with gr.Row(variant="compact", elem_id="axis_options"):
@@ -413,18 +416,20 @@ class Script(scripts.Script):

        def fill(x_type):
            axis = self.current_axis_options[x_type]
            return ", ".join(axis.choices()) if axis.choices else gr.update()
            return axis.choices() if axis.choices else gr.update()

        fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
        fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])
        fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values])
        fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values_dropdown])
        fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values_dropdown])
        fill_z_button.click(fn=fill, inputs=[z_type], outputs=[z_values_dropdown])

        def select_axis(x_type):
            return gr.Button.update(visible=self.current_axis_options[x_type].choices is not None)
            choices = self.current_axis_options[x_type].choices
            has_choices = choices is not None
            return gr.Button.update(visible=has_choices),gr.Textbox.update(visible=not has_choices),gr.update(choices=choices() if has_choices else None,visible=has_choices,value=[])

        x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
        y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])
        z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button])
        x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button,x_values,x_values_dropdown])
        y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button,y_values,y_values_dropdown])
        z_type.change(fn=select_axis, inputs=[z_type], outputs=[fill_z_button,z_values,z_values_dropdown])

        self.infotext_fields = (
            (x_type, "X Type"),
@@ -435,19 +440,22 @@ class Script(scripts.Script):
            (z_values, "Z Values"),
        )

        return [x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]
        return [x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size]

    def run(self, p, x_type, x_values, y_type, y_values, z_type, z_values, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
    def run(self, p, x_type, x_values, x_values_dropdown, y_type, y_values, y_values_dropdown, z_type, z_values, z_values_dropdown, draw_legend, include_lone_images, include_sub_grids, no_fixed_seeds, margin_size):
        if not no_fixed_seeds:
            modules.processing.fix_seed(p)

        if not opts.return_grid:
            p.batch_size = 1

        def process_axis(opt, vals):
        def process_axis(opt, vals, vals_dropdown):
            if opt.label == 'Nothing':
                return [0]

            if opt.choices is not None:
                valslist = vals_dropdown
            else:
                valslist = [x.strip() for x in chain.from_iterable(csv.reader(StringIO(vals))) if x]

            if opt.type == int:
@@ -506,13 +514,13 @@ class Script(scripts.Script):
            return valslist

        x_opt = self.current_axis_options[x_type]
        xs = process_axis(x_opt, x_values)
        xs = process_axis(x_opt, x_values, x_values_dropdown)

        y_opt = self.current_axis_options[y_type]
        ys = process_axis(y_opt, y_values)
        ys = process_axis(y_opt, y_values, y_values_dropdown)

        z_opt = self.current_axis_options[z_type]
        zs = process_axis(z_opt, z_values)
        zs = process_axis(z_opt, z_values, z_values_dropdown)

        # this could be moved to common code, but unlikely to be ever triggered anywhere else
        Image.MAX_IMAGE_PIXELS = None # disable check in Pillow and rely on check below to allow large custom image sizes