Commit 1610b325 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add callback for creating a tab in train UI

parent 8011be33
Loading
Loading
Loading
Loading
+25 −2
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ from typing import Optional
from fastapi import FastAPI
from gradio import Blocks


def report_exception(c, job):
    print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)
@@ -45,15 +46,21 @@ class CFGDenoiserParams:
        """Total number of sampling steps planned"""


class UiTrainTabParams:
    def __init__(self, txt2img_preview_params):
        self.txt2img_preview_params = txt2img_preview_params


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callback_map = dict(
    callbacks_app_started=[],
    callbacks_model_loaded=[],
    callbacks_ui_tabs=[],
    callbacks_ui_train_tabs=[],
    callbacks_ui_settings=[],
    callbacks_before_image_saved=[],
    callbacks_image_saved=[],
    callbacks_cfg_denoiser=[]
    callbacks_cfg_denoiser=[],
)


@@ -61,6 +68,7 @@ def clear_callbacks():
    for callback_list in callback_map.values():
        callback_list.clear()


def app_started_callback(demo: Optional[Blocks], app: FastAPI):
    for c in callback_map['callbacks_app_started']:
        try:
@@ -89,6 +97,14 @@ def ui_tabs_callback():
    return res


def ui_train_tabs_callback(params: UiTrainTabParams):
    for c in callback_map['callbacks_ui_train_tabs']:
        try:
            c.callback(params)
        except Exception:
            report_exception(c, 'callbacks_ui_train_tabs')


def ui_settings_callback():
    for c in callback_map['callbacks_ui_settings']:
        try:
@@ -169,6 +185,13 @@ def on_ui_tabs(callback):
    add_callback(callback_map['callbacks_ui_tabs'], callback)


def on_ui_train_tabs(callback):
    """register a function to be called when the UI is creating new tabs for the train tab.
    Create your new tabs with gr.Tab.
    """
    add_callback(callback_map['callbacks_ui_train_tabs'], callback)


def on_ui_settings(callback):
    """register a function to be called before UI settings are populated; add your settings
    by using shared.opts.add_option(shared.OptionInfo(...)) """
+4 −0
Original line number Diff line number Diff line
@@ -1270,6 +1270,10 @@ def create_ui(wrap_gradio_gpu_call):
                        train_hypernetwork = gr.Button(value="Train Hypernetwork", variant='primary')
                        train_embedding = gr.Button(value="Train Embedding", variant='primary')

                params = script_callbacks.UiTrainTabParams(txt2img_preview_params)

                script_callbacks.ui_train_tabs_callback(params)

            with gr.Column():
                progressbar = gr.HTML(elem_id="ti_progressbar")
                ti_output = gr.Text(elem_id="ti_output", value="", show_label=False)