Commit 3596af07 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

Add API for scripts to add elements anywhere in UI.

parent ccd73fc1
Loading
Loading
Loading
Loading
+35 −0
Original line number Diff line number Diff line
@@ -61,6 +61,8 @@ callback_map = dict(
    callbacks_before_image_saved=[],
    callbacks_image_saved=[],
    callbacks_cfg_denoiser=[],
    callbacks_before_component=[],
    callbacks_after_component=[],
)


@@ -137,6 +139,22 @@ def cfg_denoiser_callback(params: CFGDenoiserParams):
            report_exception(c, 'cfg_denoiser_callback')


def before_component_callback(component, **kwargs):
    for c in callback_map['callbacks_before_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'before_component_callback')


def after_component_callback(component, **kwargs):
    for c in callback_map['callbacks_after_component']:
        try:
            c.callback(component, **kwargs)
        except Exception:
            report_exception(c, 'after_component_callback')


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -220,3 +238,20 @@ def on_cfg_denoiser(callback):
        - params: CFGDenoiserParams - parameters to be passed to the inner model and sampling state details.
    """
    add_callback(callback_map['callbacks_cfg_denoiser'], callback)


def on_before_component(callback):
    """register a function to be called before a component is created.
    The callback is called with arguments:
        - component - gradio component that is about to be created.
        - **kwargs - args to gradio.components.IOComponent.__init__ function

    Use elem_id/label fields of kwargs to figure out which component it is.
    This can be useful to inject your own components somewhere in the middle of vanilla UI.
    """
    add_callback(callback_map['callbacks_before_component'], callback)


def on_after_component(callback):
    """register a function to be called after a component is created. See on_before_component for more."""
    add_callback(callback_map['callbacks_after_component'], callback)
+66 −3
Original line number Diff line number Diff line
@@ -17,6 +17,9 @@ class Script:
    args_to = None
    alwayson = False

    is_txt2img = False
    is_img2img = False

    """A gr.Group component that has all script's UI inside it"""
    group = None

@@ -93,6 +96,23 @@ class Script:

        pass

    def before_component(self, component, **kwargs):
        """
        Called before a component is created.
        Use elem_id/label fields of kwargs to figure out which component it is.
        This can be useful to inject your own components somewhere in the middle of vanilla UI.
        You can return created components in the ui() function to add them to the list of arguments for your processing functions
        """

        pass

    def after_component(self, component, **kwargs):
        """
        Called after a component is created. Same as above.
        """

        pass

    def describe(self):
        """unused"""
        return ""
@@ -195,12 +215,18 @@ class ScriptRunner:
        self.titles = []
        self.infotext_fields = []

    def setup_ui(self, is_img2img):
    def initialize_scripts(self, is_img2img):
        self.scripts.clear()
        self.alwayson_scripts.clear()
        self.selectable_scripts.clear()

        for script_class, path, basedir in scripts_data:
            script = script_class()
            script.filename = path
            script.is_txt2img = not is_img2img
            script.is_img2img = is_img2img

            visibility = script.show(is_img2img)
            visibility = script.show(script.is_img2img)

            if visibility == AlwaysVisible:
                self.scripts.append(script)
@@ -211,6 +237,7 @@ class ScriptRunner:
                self.scripts.append(script)
                self.selectable_scripts.append(script)

    def setup_ui(self):
        self.titles = [wrap_call(script.title, script.filename, "title") or f"{script.filename} [error]" for script in self.selectable_scripts]

        inputs = [None]
@@ -220,7 +247,7 @@ class ScriptRunner:
            script.args_from = len(inputs)
            script.args_to = len(inputs)

            controls = wrap_call(script.ui, script.filename, "ui", is_img2img)
            controls = wrap_call(script.ui, script.filename, "ui", script.is_img2img)

            if controls is None:
                return
@@ -320,6 +347,22 @@ class ScriptRunner:
                print(f"Error running postprocess: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def before_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.before_component(component, **kwargs)
            except Exception:
                print(f"Error running before_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def after_component(self, component, **kwargs):
        for script in self.scripts:
            try:
                script.after_component(component, **kwargs)
            except Exception:
                print(f"Error running after_component: {script.filename}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)

    def reload_sources(self, cache):
        for si, script in list(enumerate(self.scripts)):
            args_from = script.args_from
@@ -341,6 +384,7 @@ class ScriptRunner:

scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
scripts_current: ScriptRunner = None


def reload_script_body_only():
@@ -357,3 +401,22 @@ def reload_scripts():
    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()


def IOComponent_init(self, *args, **kwargs):
    if scripts_current is not None:
        scripts_current.before_component(self, **kwargs)

    script_callbacks.before_component_callback(self, **kwargs)

    res = original_IOComponent_init(self, *args, **kwargs)

    script_callbacks.after_component_callback(self, **kwargs)

    if scripts_current is not None:
        scripts_current.after_component(self, **kwargs)

    return res


original_IOComponent_init = gr.components.IOComponent.__init__
gr.components.IOComponent.__init__ = IOComponent_init
+10 −2
Original line number Diff line number Diff line
@@ -695,6 +695,9 @@ def create_ui(wrap_gradio_gpu_call):

    parameters_copypaste.reset()

    modules.scripts.scripts_current = modules.scripts.scripts_txt2img
    modules.scripts.scripts_txt2img.initialize_scripts(is_img2img=False)

    with gr.Blocks(analytics_enabled=False) as txt2img_interface:
        txt2img_prompt, roll, txt2img_prompt_style, txt2img_negative_prompt, txt2img_prompt_style2, submit, _, _, txt2img_prompt_style_apply, txt2img_save_style, txt2img_paste, token_counter, token_button = create_toprow(is_img2img=False)
        dummy_component = gr.Label(visible=False)
@@ -737,7 +740,7 @@ def create_ui(wrap_gradio_gpu_call):
                seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()

                with gr.Group():
                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui(is_img2img=False)
                    custom_inputs = modules.scripts.scripts_txt2img.setup_ui()

            txt2img_gallery, generation_info, html_info = create_output_panel("txt2img", opts.outdir_txt2img_samples)
            parameters_copypaste.bind_buttons({"txt2img": txt2img_paste}, None, txt2img_prompt)
@@ -846,6 +849,9 @@ def create_ui(wrap_gradio_gpu_call):

            token_button.click(fn=update_token_counter, inputs=[txt2img_prompt, steps], outputs=[token_counter])

    modules.scripts.scripts_current = modules.scripts.scripts_img2img
    modules.scripts.scripts_img2img.initialize_scripts(is_img2img=True)

    with gr.Blocks(analytics_enabled=False) as img2img_interface:
        img2img_prompt, roll, img2img_prompt_style, img2img_negative_prompt, img2img_prompt_style2, submit, img2img_interrogate, img2img_deepbooru, img2img_prompt_style_apply, img2img_save_style, img2img_paste, token_counter, token_button = create_toprow(is_img2img=True)

@@ -916,7 +922,7 @@ def create_ui(wrap_gradio_gpu_call):
                seed, reuse_seed, subseed, reuse_subseed, subseed_strength, seed_resize_from_h, seed_resize_from_w, seed_checkbox = create_seed_inputs()

                with gr.Group():
                    custom_inputs = modules.scripts.scripts_img2img.setup_ui(is_img2img=True)
                    custom_inputs = modules.scripts.scripts_img2img.setup_ui()

            img2img_gallery, generation_info, html_info = create_output_panel("img2img", opts.outdir_img2img_samples)
            parameters_copypaste.bind_buttons({"img2img": img2img_paste}, None, img2img_prompt)
@@ -1065,6 +1071,8 @@ def create_ui(wrap_gradio_gpu_call):
            parameters_copypaste.add_paste_fields("img2img", init_img, img2img_paste_fields)
            parameters_copypaste.add_paste_fields("inpaint", init_img_with_mask, img2img_paste_fields)

    modules.scripts.scripts_current = None

    with gr.Blocks(analytics_enabled=False) as extras_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):