Commit 20549a50 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

add style editor dialog

rework toprow for img2img and txt2img to use a class with fields
fix the console error when editing checkpoint user metadata
parent 8e840e15
Loading
Loading
Loading
Loading
+1 −1
Original line number Original line Diff line number Diff line
@@ -68,7 +68,7 @@ class CheckpointInfo:


        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'
        self.title = name if self.shorthash is None else f'{name} [{self.shorthash}]'


        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])
        self.ids = [self.hash, self.model_name, self.title, name, self.name_for_extra, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])


    def register(self):
    def register(self):
        checkpoints_list[self.title] = self
        checkpoints_list[self.title] = self
+1 −4
Original line number Original line Diff line number Diff line
@@ -106,10 +106,7 @@ class StyleDatabase:
        if os.path.exists(path):
        if os.path.exists(path):
            shutil.copy(path, f"{path}.bak")
            shutil.copy(path, f"{path}.bak")


        fd = os.open(path, os.O_RDWR | os.O_CREAT)
        with open(path, "w", encoding="utf-8-sig", newline='') as file:
        with os.fdopen(fd, "w", encoding="utf-8-sig", newline='') as file:
            # _fields is actually part of the public API: typing.NamedTuple is a replacement for collections.NamedTuple,
            # and collections.NamedTuple has explicit documentation for accessing _fields. Same goes for _asdict()
            writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
            writer = csv.DictWriter(file, fieldnames=PromptStyle._fields)
            writer.writeheader()
            writer.writeheader()
            writer.writerows(style._asdict() for k, style in self.styles.items())
            writer.writerows(style._asdict() for k, style in self.styles.items())
+92 −138
Original line number Original line Diff line number Diff line
@@ -12,7 +12,7 @@ import numpy as np
from PIL import Image, PngImagePlugin  # noqa: F401
from PIL import Image, PngImagePlugin  # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call


from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger
from modules import sd_hijack, sd_models, script_callbacks, ui_extensions, deepbooru, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave, errors, shared_items, ui_settings, timer, sysinfo, ui_checkpoint_merger, ui_prompt_styles
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path
from modules.paths import script_path
from modules.ui_common import create_refresh_button
from modules.ui_common import create_refresh_button
@@ -92,19 +92,6 @@ def send_gradio_gallery_to_image(x):
    return image_from_url_text(x[0])
    return image_from_url_text(x[0])




def add_style(name: str, prompt: str, negative_prompt: str):
    if name is None:
        return [gr_show() for x in range(4)]

    style = modules.styles.PromptStyle(name, prompt, negative_prompt)
    shared.prompt_styles.styles[style.name] = style
    # Save all loaded prompt styles: this allows us to update the storage format in the future more easily, because we
    # reserialize all styles every time we save them
    shared.prompt_styles.save_styles(shared.styles_filename)

    return [gr.Dropdown.update(visible=True, choices=list(shared.prompt_styles.styles)) for _ in range(2)]


def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
def calc_resolution_hires(enable, width, height, hr_scale, hr_resize_x, hr_resize_y):
    from modules import processing, devices
    from modules import processing, devices


@@ -129,13 +116,6 @@ def resize_from_to_html(width, height, scale_by):
    return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"
    return f"resize: from <span class='resolution'>{width}x{height}</span> to <span class='resolution'>{target_width}x{target_height}</span>"




def apply_styles(prompt, prompt_neg, styles):
    prompt = shared.prompt_styles.apply_styles_to_prompt(prompt, styles)
    prompt_neg = shared.prompt_styles.apply_negative_styles_to_prompt(prompt_neg, styles)

    return [gr.Textbox.update(value=prompt), gr.Textbox.update(value=prompt_neg), gr.Dropdown.update(value=[])]


def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
def process_interrogate(interrogation_function, mode, ii_input_dir, ii_output_dir, *ii_singles):
    if mode in {0, 1, 3, 4}:
    if mode in {0, 1, 3, 4}:
        return [interrogation_function(ii_singles[mode]), None]
        return [interrogation_function(ii_singles[mode]), None]
@@ -267,71 +247,67 @@ def update_token_counter(text, steps):
    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"
    return f"<span class='gr-box gr-text-input'>{token_count}/{max_length}</span>"




def create_toprow(is_img2img):
class Toprow:
    def __init__(self, is_img2img):
        id_part = "img2img" if is_img2img else "txt2img"
        id_part = "img2img" if is_img2img else "txt2img"
        self.id_part = id_part


        with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
        with gr.Row(elem_id=f"{id_part}_toprow", variant="compact"):
            with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
            with gr.Column(elem_id=f"{id_part}_prompt_container", scale=6):
                with gr.Row():
                with gr.Row():
                    with gr.Column(scale=80):
                    with gr.Column(scale=80):
                        with gr.Row():
                        with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
                            self.prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, lines=3, placeholder="Prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])


                with gr.Row():
                with gr.Row():
                    with gr.Column(scale=80):
                    with gr.Column(scale=80):
                        with gr.Row():
                        with gr.Row():
                        negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])
                            self.negative_prompt = gr.Textbox(label="Negative prompt", elem_id=f"{id_part}_neg_prompt", show_label=False, lines=3, placeholder="Negative prompt (press Ctrl+Enter or Alt+Enter to generate)", elem_classes=["prompt"])


        button_interrogate = None
            self.button_interrogate = None
        button_deepbooru = None
            self.button_deepbooru = None
            if is_img2img:
            if is_img2img:
                with gr.Column(scale=1, elem_classes="interrogate-col"):
                with gr.Column(scale=1, elem_classes="interrogate-col"):
                button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                    self.button_interrogate = gr.Button('Interrogate\nCLIP', elem_id="interrogate")
                button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")
                    self.button_deepbooru = gr.Button('Interrogate\nDeepBooru', elem_id="deepbooru")


            with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
            with gr.Column(scale=1, elem_id=f"{id_part}_actions_column"):
                with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
                with gr.Row(elem_id=f"{id_part}_generate_box", elem_classes="generate-box"):
                interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
                    self.interrupt = gr.Button('Interrupt', elem_id=f"{id_part}_interrupt", elem_classes="generate-box-interrupt")
                skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
                    self.skip = gr.Button('Skip', elem_id=f"{id_part}_skip", elem_classes="generate-box-skip")
                submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')
                    self.submit = gr.Button('Generate', elem_id=f"{id_part}_generate", variant='primary')


                skip.click(
                    self.skip.click(
                        fn=lambda: shared.state.skip(),
                        fn=lambda: shared.state.skip(),
                        inputs=[],
                        inputs=[],
                        outputs=[],
                        outputs=[],
                    )
                    )


                interrupt.click(
                    self.interrupt.click(
                        fn=lambda: shared.state.interrupt(),
                        fn=lambda: shared.state.interrupt(),
                        inputs=[],
                        inputs=[],
                        outputs=[],
                        outputs=[],
                    )
                    )


                with gr.Row(elem_id=f"{id_part}_tools"):
                with gr.Row(elem_id=f"{id_part}_tools"):
                paste = ToolButton(value=paste_symbol, elem_id="paste")
                    self.paste = ToolButton(value=paste_symbol, elem_id="paste")
                clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
                    self.clear_prompt_button = ToolButton(value=clear_prompt_symbol, elem_id=f"{id_part}_clear_prompt")
                extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
                    self.extra_networks_button = ToolButton(value=extra_networks_symbol, elem_id=f"{id_part}_extra_networks")
                prompt_style_apply = ToolButton(value=apply_style_symbol, elem_id=f"{id_part}_style_apply")
                    self.restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
                save_style = ToolButton(value=save_style_symbol, elem_id=f"{id_part}_style_create")

                restore_progress_button = ToolButton(value=restore_progress_symbol, elem_id=f"{id_part}_restore_progress", visible=False)
                    self.token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])

                    self.token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
                token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_token_counter", elem_classes=["token-counter"])
                    self.negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])
                token_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
                    self.negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
                negative_token_counter = gr.HTML(value="<span>0/75</span>", elem_id=f"{id_part}_negative_token_counter", elem_classes=["token-counter"])

                negative_token_button = gr.Button(visible=False, elem_id=f"{id_part}_negative_token_button")
                    self.clear_prompt_button.click(

                clear_prompt_button.click(
                        fn=lambda *x: x,
                        fn=lambda *x: x,
                        _js="confirm_clear_prompt",
                        _js="confirm_clear_prompt",
                    inputs=[prompt, negative_prompt],
                        inputs=[self.prompt, self.negative_prompt],
                    outputs=[prompt, negative_prompt],
                        outputs=[self.prompt, self.negative_prompt],
                    )
                    )


            with gr.Row(elem_id=f"{id_part}_styles_row"):
                self.ui_styles = ui_prompt_styles.UiPromptStyles(id_part, self.prompt, self.negative_prompt)
                prompt_styles = gr.Dropdown(label="Styles", elem_id=f"{id_part}_styles", choices=[k for k, v in shared.prompt_styles.styles.items()], value=[], multiselect=True)
                create_refresh_button(prompt_styles, shared.prompt_styles.reload, lambda: {"choices": [k for k, v in shared.prompt_styles.styles.items()]}, f"refresh_{id_part}_styles")

    return prompt, prompt_styles, negative_prompt, submit, button_interrogate, button_deepbooru, prompt_style_apply, save_style, paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button




def setup_progressbar(*args, **kwargs):
def setup_progressbar(*args, **kwargs):
@@ -419,14 +395,14 @@ def create_ui():
    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)
    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)


    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
        txt2img_prompt, txt2img_prompt_styles, txt2img_negative_prompt, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=False)
        toprow = txt2img_toprow = Toprow(is_img2img=False)


        dummy_component = gr.Label(visible=False)
        dummy_component = gr.Label(visible=False)
        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)
        txt_prompt_img = gr.File(label="", elem_id="txt2img_prompt_image", file_count="single", type="binary", visible=False)


        with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
        with FormRow(variant='compact', elem_id="txt2img_extra_networks", visible=False) as extra_networks:
            from modules import ui_extra_networks
            from modules import ui_extra_networks
            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'txt2img')
            extra_networks_ui = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'txt2img')


        with gr.Row().style(equal_height=False):
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='compact', elem_id="txt2img_settings"):
            with gr.Column(variant='compact', elem_id="txt2img_settings"):
@@ -532,9 +508,9 @@ def create_ui():
                _js="submit",
                _js="submit",
                inputs=[
                inputs=[
                    dummy_component,
                    dummy_component,
                    txt2img_prompt,
                    toprow.prompt,
                    txt2img_negative_prompt,
                    toprow.negative_prompt,
                    txt2img_prompt_styles,
                    toprow.ui_styles.dropdown,
                    steps,
                    steps,
                    sampler_index,
                    sampler_index,
                    restore_faces,
                    restore_faces,
@@ -569,12 +545,12 @@ def create_ui():
                show_progress=False,
                show_progress=False,
            )
            )


            txt2img_prompt.submit(**txt2img_args)
            toprow.prompt.submit(**txt2img_args)
            submit.click(**txt2img_args)
            toprow.submit.click(**txt2img_args)


            res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)
            res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('txt2img')}", inputs=None, outputs=None, show_progress=False)


            restore_progress_button.click(
            toprow.restore_progress_button.click(
                fn=progress.restore_progress,
                fn=progress.restore_progress,
                _js="restoreProgressTxt2img",
                _js="restoreProgressTxt2img",
                inputs=[dummy_component],
                inputs=[dummy_component],
@@ -593,7 +569,7 @@ def create_ui():
                    txt_prompt_img
                    txt_prompt_img
                ],
                ],
                outputs=[
                outputs=[
                    txt2img_prompt,
                    toprow.prompt,
                    txt_prompt_img
                    txt_prompt_img
                ],
                ],
                show_progress=False,
                show_progress=False,
@@ -607,8 +583,8 @@ def create_ui():
            )
            )


            txt2img_paste_fields = [
            txt2img_paste_fields = [
                (txt2img_prompt, "Prompt"),
                (toprow.prompt, "Prompt"),
                (txt2img_negative_prompt, "Negative prompt"),
                (toprow.negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (restore_faces, "Face restoration"),
@@ -621,7 +597,7 @@ def create_ui():
                (subseed_strength, "Variation seed strength"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (seed_resize_from_h, "Seed resize from-2"),
                (txt2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                (denoising_strength, "Denoising strength"),
                (denoising_strength, "Denoising strength"),
                (enable_hr, lambda d: "Denoising strength" in d),
                (enable_hr, lambda d: "Denoising strength" in d),
                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
                (hr_options, lambda d: gr.Row.update(visible="Denoising strength" in d)),
@@ -639,12 +615,12 @@ def create_ui():
            ]
            ]
            parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
            parameters_copypaste.add_paste_fields("txt2img", None, txt2img_paste_fields, override_settings)
            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
                paste_button=txt2img_paste, tabname="txt2img", source_text_component=txt2img_prompt, source_image_component=None,
                paste_button=toprow.paste, tabname="txt2img", source_text_component=toprow.prompt, source_image_component=None,
            ))
            ))


            txt2img_preview_params = [
            txt2img_preview_params = [
                txt2img_prompt,
                toprow.prompt,
                txt2img_negative_prompt,
                toprow.negative_prompt,
                steps,
                steps,
                sampler_index,
                sampler_index,
                cfg_scale,
                cfg_scale,
@@ -653,8 +629,8 @@ def create_ui():
                height,
                height,
            ]
            ]


            token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_prompt, steps], outputs=[token_counter])
            toprow.token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[txt2img_negative_prompt, steps], outputs=[negative_token_counter])
            toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])


            ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)
            ui_extra_networks.setup_ui(extra_networks_ui, txt2img_gallery)


@@ -662,13 +638,13 @@ def create_ui():
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)


    with gr.Blocks(analytics_enabled=False) as img2img_interface:
    with gr.Blocks(analytics_enabled=False) as img2img_interface:
        img2img_prompt, img2img_prompt_styles, img2img_negative_prompt, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, extra_networks_button, token_counter, token_button, negative_token_counter, negative_token_button, restore_progress_button = create_toprow(is_img2img=True)
        toprow = img2img_toprow = Toprow(is_img2img=True)


        img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)
        img2img_prompt_img = gr.File(label="", elem_id="img2img_prompt_image", file_count="single", type="binary", visible=False)


        with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
        with FormRow(variant='compact', elem_id="img2img_extra_networks", visible=False) as extra_networks:
            from modules import ui_extra_networks
            from modules import ui_extra_networks
            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, extra_networks_button, 'img2img')
            extra_networks_ui_img2img = ui_extra_networks.create_ui(extra_networks, toprow.extra_networks_button, 'img2img')


        with FormRow().style(equal_height=False):
        with FormRow().style(equal_height=False):
            with gr.Column(variant='compact', elem_id="img2img_settings"):
            with gr.Column(variant='compact', elem_id="img2img_settings"):
@@ -889,7 +865,7 @@ def create_ui():
                    img2img_prompt_img
                    img2img_prompt_img
                ],
                ],
                outputs=[
                outputs=[
                    img2img_prompt,
                    toprow.prompt,
                    img2img_prompt_img
                    img2img_prompt_img
                ],
                ],
                show_progress=False,
                show_progress=False,
@@ -901,9 +877,9 @@ def create_ui():
                inputs=[
                inputs=[
                    dummy_component,
                    dummy_component,
                    dummy_component,
                    dummy_component,
                    img2img_prompt,
                    toprow.prompt,
                    img2img_negative_prompt,
                    toprow.negative_prompt,
                    img2img_prompt_styles,
                    toprow.ui_styles.dropdown,
                    init_img,
                    init_img,
                    sketch,
                    sketch,
                    init_img_with_mask,
                    init_img_with_mask,
@@ -962,11 +938,11 @@ def create_ui():
                    inpaint_color_sketch,
                    inpaint_color_sketch,
                    init_img_inpaint,
                    init_img_inpaint,
                ],
                ],
                outputs=[img2img_prompt, dummy_component],
                outputs=[toprow.prompt, dummy_component],
            )
            )


            img2img_prompt.submit(**img2img_args)
            toprow.prompt.submit(**img2img_args)
            submit.click(**img2img_args)
            toprow.submit.click(**img2img_args)


            res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)
            res_switch_btn.click(fn=None, _js="function(){switchWidthHeight('img2img')}", inputs=None, outputs=None, show_progress=False)


@@ -978,7 +954,7 @@ def create_ui():
                show_progress=False,
                show_progress=False,
            )
            )


            restore_progress_button.click(
            toprow.restore_progress_button.click(
                fn=progress.restore_progress,
                fn=progress.restore_progress,
                _js="restoreProgressImg2img",
                _js="restoreProgressImg2img",
                inputs=[dummy_component],
                inputs=[dummy_component],
@@ -991,46 +967,24 @@ def create_ui():
                show_progress=False,
                show_progress=False,
            )
            )


            img2img_interrogate.click(
            toprow.button_interrogate.click(
                fn=lambda *args: process_interrogate(interrogate, *args),
                fn=lambda *args: process_interrogate(interrogate, *args),
                **interrogate_args,
                **interrogate_args,
            )
            )


            img2img_deepbooru.click(
            toprow.button_deepbooru.click(
                fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
                fn=lambda *args: process_interrogate(interrogate_deepbooru, *args),
                **interrogate_args,
                **interrogate_args,
            )
            )


            prompts = [(txt2img_prompt, txt2img_negative_prompt), (img2img_prompt, img2img_negative_prompt)]
            toprow.token_button.click(fn=update_token_counter, inputs=[toprow.prompt, steps], outputs=[toprow.token_counter])
            style_dropdowns = [txt2img_prompt_styles, img2img_prompt_styles]
            toprow.negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[toprow.negative_prompt, steps], outputs=[toprow.negative_token_counter])
            style_js_funcs = ["update_txt2img_tokens", "update_img2img_tokens"]

            for button, (prompt, negative_prompt) in zip([txt2img_save_style, img2img_save_style], prompts):
                button.click(
                    fn=add_style,
                    _js="ask_for_style_name",
                    # Have to pass empty dummy component here, because the JavaScript and Python function have to accept
                    # the same number of parameters, but we only know the style-name after the JavaScript prompt
                    inputs=[dummy_component, prompt, negative_prompt],
                    outputs=[txt2img_prompt_styles, img2img_prompt_styles],
                )

            for button, (prompt, negative_prompt), styles, js_func in zip([txt2img_prompt_style_apply, img2img_prompt_style_apply], prompts, style_dropdowns, style_js_funcs):
                button.click(
                    fn=apply_styles,
                    _js=js_func,
                    inputs=[prompt, negative_prompt, styles],
                    outputs=[prompt, negative_prompt, styles],
                )

            token_button.click(fn=update_token_counter, inputs=[img2img_prompt, steps], outputs=[token_counter])
            negative_token_button.click(fn=wrap_queued_call(update_token_counter), inputs=[img2img_negative_prompt, steps], outputs=[negative_token_counter])


            ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)
            ui_extra_networks.setup_ui(extra_networks_ui_img2img, img2img_gallery)


            img2img_paste_fields = [
            img2img_paste_fields = [
                (img2img_prompt, "Prompt"),
                (toprow.prompt, "Prompt"),
                (img2img_negative_prompt, "Negative prompt"),
                (toprow.negative_prompt, "Negative prompt"),
                (steps, "Steps"),
                (steps, "Steps"),
                (sampler_index, "Sampler"),
                (sampler_index, "Sampler"),
                (restore_faces, "Face restoration"),
                (restore_faces, "Face restoration"),
@@ -1044,7 +998,7 @@ def create_ui():
                (subseed_strength, "Variation seed strength"),
                (subseed_strength, "Variation seed strength"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_w, "Seed resize from-1"),
                (seed_resize_from_h, "Seed resize from-2"),
                (seed_resize_from_h, "Seed resize from-2"),
                (img2img_prompt_styles, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                (toprow.ui_styles.dropdown, lambda d: d["Styles array"] if isinstance(d.get("Styles array"), list) else gr.update()),
                (denoising_strength, "Denoising strength"),
                (denoising_strength, "Denoising strength"),
                (mask_blur, "Mask blur"),
                (mask_blur, "Mask blur"),
                *modules.scripts.scripts_img2img.infotext_fields
                *modules.scripts.scripts_img2img.infotext_fields
@@ -1052,7 +1006,7 @@ def create_ui():
            parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
            parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields, override_settings)
            parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
            parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields, override_settings)
            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
            parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
                paste_button=img2img_paste, tabname="img2img", source_text_component=img2img_prompt, source_image_component=None,
                paste_button=toprow.paste, tabname="img2img", source_text_component=toprow.prompt, source_image_component=None,
            ))
            ))


    modules.scripts.scripts_current = None
    modules.scripts.scripts_current = None
+28 −4
Original line number Original line Diff line number Diff line
@@ -223,20 +223,44 @@ Requested path was: {f}




def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
def create_refresh_button(refresh_component, refresh_method, refreshed_args, elem_id):
    refresh_components = refresh_component if isinstance(refresh_component, list) else [refresh_component]

    label = None
    for comp in refresh_components:
        label = getattr(comp, 'label', None)
        if label is not None:
            break

    def refresh():
    def refresh():
        refresh_method()
        refresh_method()
        args = refreshed_args() if callable(refreshed_args) else refreshed_args
        args = refreshed_args() if callable(refreshed_args) else refreshed_args


        for k, v in args.items():
        for k, v in args.items():
            setattr(refresh_component, k, v)
            for comp in refresh_components:
                setattr(comp, k, v)


        return gr.update(**(args or {}))
        return [gr.update(**(args or {})) for _ in refresh_components]


    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id)
    refresh_button = ToolButton(value=refresh_symbol, elem_id=elem_id, tooltip=f"{label}: refresh" if label else "Refresh")
    refresh_button.click(
    refresh_button.click(
        fn=refresh,
        fn=refresh,
        inputs=[],
        inputs=[],
        outputs=[refresh_component]
        outputs=[*refresh_components]
    )
    )
    return refresh_button
    return refresh_button



def setup_dialog(button_show, dialog, *, button_close=None):
    """Sets up the UI so that the dialog (gr.Box) is invisible, and is only shown when buttons_show is clicked, in a fullscreen modal window."""

    dialog.visible = False

    button_show.click(
        fn=lambda: gr.update(visible=True),
        inputs=[],
        outputs=[dialog],
    ).then(fn=None, _js="function(){ popup(gradioApp().getElementById('" + dialog.elem_id + "')); }")

    if button_close:
        button_close.click(fn=None, _js="closePopup")
+1 −1
Original line number Original line Diff line number Diff line
@@ -12,7 +12,7 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
    def refresh(self):
    def refresh(self):
        shared.refresh_checkpoints()
        shared.refresh_checkpoints()


    def create_item(self, name, index=None):
    def create_item(self, name, index=None, enable_filter=True):
        checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
        checkpoint: sd_models.CheckpointInfo = sd_models.checkpoint_aliases.get(name)
        path, ext = os.path.splitext(checkpoint.filename)
        path, ext = os.path.splitext(checkpoint.filename)
        return {
        return {
Loading