Commit 7df7e4d2 authored by space-nuko's avatar space-nuko
Browse files

Allow extensions to declare paste fields for "Send to X" buttons

parent 3715ece0
Loading
Loading
Loading
Loading
+3 −2
Original line number Original line Diff line number Diff line
@@ -23,13 +23,14 @@ registered_param_bindings = []




class ParamBinding:
class ParamBinding:
    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None):
    def __init__(self, paste_button, tabname, source_text_component=None, source_image_component=None, source_tabname=None, override_settings_component=None, paste_field_names=[]):
        self.paste_button = paste_button
        self.paste_button = paste_button
        self.tabname = tabname
        self.tabname = tabname
        self.source_text_component = source_text_component
        self.source_text_component = source_text_component
        self.source_image_component = source_image_component
        self.source_image_component = source_image_component
        self.source_tabname = source_tabname
        self.source_tabname = source_tabname
        self.override_settings_component = override_settings_component
        self.override_settings_component = override_settings_component
        self.paste_field_names = paste_field_names




def reset():
def reset():
@@ -133,7 +134,7 @@ def connect_paste_params_buttons():
            connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)
            connect_paste(binding.paste_button, fields, binding.source_text_component, binding.override_settings_component, binding.tabname)


        if binding.source_tabname is not None and fields is not None:
        if binding.source_tabname is not None and fields is not None:
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else [])
            paste_field_names = ['Prompt', 'Negative prompt', 'Steps', 'Face restoration'] + (["Seed"] if shared.opts.send_seed else []) + binding.paste_field_names
            binding.paste_button.click(
            binding.paste_button.click(
                fn=lambda *x: x,
                fn=lambda *x: x,
                inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
                inputs=[field for field, name in paste_fields[binding.source_tabname]["fields"] if name in paste_field_names],
+9 −0
Original line number Original line Diff line number Diff line
@@ -33,6 +33,11 @@ class Script:
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    parsing infotext to set the value for the component; see ui.py's txt2img_paste_fields for an example
    """
    """


    paste_field_names = None
    """if set in ui(), this is a list of names of infotext fields; the fields will be sent through the
    various "Send to <X>" buttons when clicked
    """

    def title(self):
    def title(self):
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""
        """this function should return the title of the script. This is what will be displayed in the dropdown menu."""


@@ -256,6 +261,7 @@ class ScriptRunner:
        self.alwayson_scripts = []
        self.alwayson_scripts = []
        self.titles = []
        self.titles = []
        self.infotext_fields = []
        self.infotext_fields = []
        self.paste_field_names = []


    def initialize_scripts(self, is_img2img):
    def initialize_scripts(self, is_img2img):
        from modules import scripts_auto_postprocessing
        from modules import scripts_auto_postprocessing
@@ -304,6 +310,9 @@ class ScriptRunner:
            if script.infotext_fields is not None:
            if script.infotext_fields is not None:
                self.infotext_fields += script.infotext_fields
                self.infotext_fields += script.infotext_fields


            if script.paste_field_names is not None:
                self.paste_field_names += script.paste_field_names

            inputs += controls
            inputs += controls
            inputs_alwayson += [script.alwayson for _ in controls]
            inputs_alwayson += [script.alwayson for _ in controls]
            script.args_to = len(inputs)
            script.args_to = len(inputs)
+8 −1
Original line number Original line Diff line number Diff line
@@ -198,9 +198,16 @@ Requested path was: {f}
                html_info = gr.HTML(elem_id=f'html_info_{tabname}')
                html_info = gr.HTML(elem_id=f'html_info_{tabname}')
                html_log = gr.HTML(elem_id=f'html_log_{tabname}')
                html_log = gr.HTML(elem_id=f'html_log_{tabname}')


            paste_field_names = []
            if tabname == "txt2img":
                paste_field_names = modules.scripts.scripts_txt2img.paste_field_names
            elif tabname == "img2img":
                paste_field_names = modules.scripts.scripts_img2img.paste_field_names

            for paste_tabname, paste_button in buttons.items():
            for paste_tabname, paste_button in buttons.items():
                parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
                parameters_copypaste.register_paste_params_button(parameters_copypaste.ParamBinding(
                    paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery
                    paste_button=paste_button, tabname=paste_tabname, source_tabname="txt2img" if tabname == "txt2img" else None, source_image_component=result_gallery,
                    paste_field_names=paste_field_names
                ))
                ))


            return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log
            return result_gallery, generation_info if tabname != "extras" else html_info_x, html_info, html_log