Commit 55947857 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add a button for XY Plot to fill in available values for axes that support this

parent d073637e
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -20,6 +20,7 @@ titles = {
    "\u{1f4be}": "Save style",
    "\u{1f4be}": "Save style",
    "\U0001F5D1": "Clear prompt",
    "\U0001F5D1": "Clear prompt",
    "\u{1f4cb}": "Apply selected styles to current prompt",
    "\u{1f4cb}": "Apply selected styles to current prompt",
    "\u{1f4d2}": "Paste available values into the field",


    "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
    "Inpaint a part of image": "Draw a mask over an image, and the script will regenerate the masked area with content according to prompt",
    "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
    "SD upscale": "Upscale image normally, split result into tiles, improve each tile using img2img, merge whole image back",
+66 −35
Original line number Original line Diff line number Diff line
@@ -10,7 +10,7 @@ import numpy as np
import modules.scripts as scripts
import modules.scripts as scripts
import gradio as gr
import gradio as gr


from modules import images, paths, sd_samplers, processing
from modules import images, paths, sd_samplers, processing, sd_models, sd_vae
from modules.hypernetworks import hypernetwork
from modules.hypernetworks import hypernetwork
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.processing import process_images, Processed, StableDiffusionProcessingTxt2Img
from modules.shared import opts, cmd_opts, state
from modules.shared import opts, cmd_opts, state
@@ -22,8 +22,9 @@ import glob
import os
import os
import re
import re


from modules.ui_components import ToolButton


up_down_arrow_symbol = "\u2195\ufe0f"
fill_values_symbol = "\U0001f4d2"  # 📒




def apply_field(field):
def apply_field(field):
@@ -178,34 +179,49 @@ def str_permutations(x):
    """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
    """dummy function for specifying it in AxisOption's type when you want to get a list of permutations"""
    return x
    return x


AxisOption = namedtuple("AxisOption", ["label", "type", "apply", "format_value", "confirm", "cost"])

AxisOptionImg2Img = namedtuple("AxisOptionImg2Img", ["label", "type", "apply", "format_value", "confirm", "cost"])
class AxisOption:
    def __init__(self, label, type, apply, format_value=format_value_add_label, confirm=None, cost=0.0, choices=None):
        self.label = label
        self.type = type
        self.apply = apply
        self.format_value = format_value
        self.confirm = confirm
        self.cost = cost
        self.choices = choices
        self.is_img2img = False


class AxisOptionImg2Img(AxisOption):
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.is_img2img = False




axis_options = [
axis_options = [
    AxisOption("Nothing", str, do_nothing, format_nothing, None, 0),
    AxisOption("Nothing", str, do_nothing, format_value=format_nothing),
    AxisOption("Seed", int, apply_field("seed"), format_value_add_label, None, 0),
    AxisOption("Seed", int, apply_field("seed")),
    AxisOption("Var. seed", int, apply_field("subseed"), format_value_add_label, None, 0),
    AxisOption("Var. seed", int, apply_field("subseed")),
    AxisOption("Var. strength", float, apply_field("subseed_strength"), format_value_add_label, None, 0),
    AxisOption("Var. strength", float, apply_field("subseed_strength")),
    AxisOption("Steps", int, apply_field("steps"), format_value_add_label, None, 0),
    AxisOption("Steps", int, apply_field("steps")),
    AxisOption("CFG Scale", float, apply_field("cfg_scale"), format_value_add_label, None, 0),
    AxisOption("CFG Scale", float, apply_field("cfg_scale")),
    AxisOption("Prompt S/R", str, apply_prompt, format_value, None, 0),
    AxisOption("Prompt S/R", str, apply_prompt, format_value=format_value),
    AxisOption("Prompt order", str_permutations, apply_order, format_value_join_list, None, 0),
    AxisOption("Prompt order", str_permutations, apply_order, format_value=format_value_join_list),
    AxisOption("Sampler", str, apply_sampler, format_value, confirm_samplers, 0),
    AxisOption("Sampler", str, apply_sampler, format_value=format_value, confirm=confirm_samplers, choices=lambda: [x.name for x in sd_samplers.samplers]),
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value, confirm_checkpoints, 1.0),
    AxisOption("Checkpoint name", str, apply_checkpoint, format_value=format_value, confirm=confirm_checkpoints, cost=1.0, choices=lambda: list(sd_models.checkpoints_list)),
    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value, confirm_hypernetworks, 0.2),
    AxisOption("Hypernetwork", str, apply_hypernetwork, format_value=format_value, confirm=confirm_hypernetworks, cost=0.2, choices=lambda: list(shared.hypernetworks)),
    AxisOption("Hypernet str.", float, apply_hypernetwork_strength, format_value_add_label, None, 0),
    AxisOption("Hypernet str.", float, apply_hypernetwork_strength),
    AxisOption("Sigma Churn", float, apply_field("s_churn"), format_value_add_label, None, 0),
    AxisOption("Sigma Churn", float, apply_field("s_churn")),
    AxisOption("Sigma min", float, apply_field("s_tmin"), format_value_add_label, None, 0),
    AxisOption("Sigma min", float, apply_field("s_tmin")),
    AxisOption("Sigma max", float, apply_field("s_tmax"), format_value_add_label, None, 0),
    AxisOption("Sigma max", float, apply_field("s_tmax")),
    AxisOption("Sigma noise", float, apply_field("s_noise"), format_value_add_label, None, 0),
    AxisOption("Sigma noise", float, apply_field("s_noise")),
    AxisOption("Eta", float, apply_field("eta"), format_value_add_label, None, 0),
    AxisOption("Eta", float, apply_field("eta")),
    AxisOption("Clip skip", int, apply_clip_skip, format_value_add_label, None, 0),
    AxisOption("Clip skip", int, apply_clip_skip),
    AxisOption("Denoising", float, apply_field("denoising_strength"), format_value_add_label, None, 0),
    AxisOption("Denoising", float, apply_field("denoising_strength")),
    AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), format_value_add_label, None, 0),
    AxisOption("Hires upscaler", str, apply_field("hr_upscaler"), choices=lambda: [x.name for x in shared.sd_upscalers]),
    AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight"), format_value_add_label, None, 0),
    AxisOption("Cond. Image Mask Weight", float, apply_field("inpainting_mask_weight")),
    AxisOption("VAE", str, apply_vae, format_value_add_label, None, 0.7),
    AxisOption("VAE", str, apply_vae, cost=0.7, choices=lambda: list(sd_vae.vae_dict)),
    AxisOption("Styles", str, apply_styles, format_value_add_label, None, 0),
    AxisOption("Styles", str, apply_styles, choices=lambda: list(shared.prompt_styles.styles)),
]
]




@@ -262,7 +278,7 @@ def draw_xy_grid(p, xs, ys, x_labels, y_labels, cell, draw_legend, include_lone_


    if not processed_result:
    if not processed_result:
        print("Unexpected error: draw_xy_grid failed to return even a single processed image")
        print("Unexpected error: draw_xy_grid failed to return even a single processed image")
        return Processed()
        return Processed(p, [])


    grid = images.image_grid(image_cache, rows=len(ys))
    grid = images.image_grid(image_cache, rows=len(ys))
    if draw_legend:
    if draw_legend:
@@ -302,23 +318,25 @@ class Script(scripts.Script):
        return "X/Y plot"
        return "X/Y plot"


    def ui(self, is_img2img):
    def ui(self, is_img2img):
        current_axis_options = [x for x in axis_options if type(x) == AxisOption or type(x) == AxisOptionImg2Img and is_img2img]
        current_axis_options = [x for x in axis_options if type(x) == AxisOption or x.is_img2img and is_img2img]


        with gr.Row():
        with gr.Row():
            with gr.Column(scale=1, elem_id="xy_grid_button_column"):
                swap_axes_button = gr.Button(value=up_down_arrow_symbol, elem_id="xy_grid_swap_axes")
            with gr.Column(scale=19):
            with gr.Column(scale=19):
                with gr.Row():
                with gr.Row():
                    x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
                    x_type = gr.Dropdown(label="X type", choices=[x.label for x in current_axis_options], value=current_axis_options[1].label, type="index", elem_id=self.elem_id("x_type"))
                    x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
                    x_values = gr.Textbox(label="X values", lines=1, elem_id=self.elem_id("x_values"))
                    fill_x_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_x_tool_button", visible=False)


                with gr.Row():
                with gr.Row():
                    y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
                    y_type = gr.Dropdown(label="Y type", choices=[x.label for x in current_axis_options], value=current_axis_options[0].label, type="index", elem_id=self.elem_id("y_type"))
                    y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
                    y_values = gr.Textbox(label="Y values", lines=1, elem_id=self.elem_id("y_values"))
                    fill_y_button = ToolButton(value=fill_values_symbol, elem_id="xy_grid_fill_y_tool_button", visible=False)


        with gr.Row(variant="compact"):
            draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
            draw_legend = gr.Checkbox(label='Draw legend', value=True, elem_id=self.elem_id("draw_legend"))
            include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
            include_lone_images = gr.Checkbox(label='Include Separate Images', value=False, elem_id=self.elem_id("include_lone_images"))
            no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
            no_fixed_seeds = gr.Checkbox(label='Keep -1 for seeds', value=False, elem_id=self.elem_id("no_fixed_seeds"))
            swap_axes_button = gr.Button(value="Swap axes", elem_id="xy_grid_swap_axes_button")


        def swap_axes(x_type, x_values, y_type, y_values):
        def swap_axes(x_type, x_values, y_type, y_values):
            nonlocal current_axis_options
            nonlocal current_axis_options
@@ -327,6 +345,19 @@ class Script(scripts.Script):
        swap_args = [x_type, x_values, y_type, y_values]
        swap_args = [x_type, x_values, y_type, y_values]
        swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)
        swap_axes_button.click(swap_axes, inputs=swap_args, outputs=swap_args)


        def fill(x_type):
            axis = axis_options[x_type]
            return ", ".join(axis.choices()) if axis.choices else gr.update()

        fill_x_button.click(fn=fill, inputs=[x_type], outputs=[x_values])
        fill_y_button.click(fn=fill, inputs=[y_type], outputs=[y_values])

        def select_axis(x_type):
            return gr.Button.update(visible=axis_options[x_type].choices is not None)

        x_type.change(fn=select_axis, inputs=[x_type], outputs=[fill_x_button])
        y_type.change(fn=select_axis, inputs=[y_type], outputs=[fill_y_button])

        return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]
        return [x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds]


    def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
    def run(self, p, x_type, x_values, y_type, y_values, draw_legend, include_lone_images, no_fixed_seeds):
+1 −11
Original line number Original line Diff line number Diff line
@@ -644,7 +644,7 @@ canvas[key="mask"] {
    max-width: 2.5em;
    max-width: 2.5em;
    min-width: 2.5em !important;
    min-width: 2.5em !important;
    height: 2.4em;
    height: 2.4em;
    margin: 0.55em 0;
    margin: 0.55em 0.7em 0.55em 0;
}
}


#quicksettings .gr-button-tool{
#quicksettings .gr-button-tool{
@@ -717,16 +717,6 @@ footer {
    line-height: 2.4em;
    line-height: 2.4em;
}
}


#xy_grid_button_column {
    min-width: 38px !important;
}

#xy_grid_button_column button {
    height: 100%;
    margin-bottom: 0.7em;
    margin-left: 1em;
}

/* The following handles localization for right-to-left (RTL) languages like Arabic.
/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running
If you change anything above, you need to make sure it is RTL compliant by just running