Unverified Commit 423f2222 authored by Maiko Tan's avatar Maiko Tan
Browse files

feat: add app started callback

parent 17a2076f
Loading
Loading
Loading
Loading
+15 −0
Original line number Diff line number Diff line
@@ -3,6 +3,8 @@ import traceback
from collections import namedtuple
import inspect

from fastapi import FastAPI
from gradio import Blocks

def report_exception(c, job):
    print(f"Error executing callback {job} for {c.script}", file=sys.stderr)
@@ -25,6 +27,7 @@ class ImageSaveParams:


ScriptCallback = namedtuple("ScriptCallback", ["script", "callback"])
callbacks_app_started = []
callbacks_model_loaded = []
callbacks_ui_tabs = []
callbacks_ui_settings = []
@@ -40,6 +43,14 @@ def clear_callbacks():
    callbacks_image_saved.clear()


def app_started_callback(demo: Blocks, app: FastAPI):
    for c in callbacks_app_started:
        try:
            c.callback(demo, app)
        except Exception:
            report_exception(c, 'app_started_callback')


def model_loaded_callback(sd_model):
    for c in callbacks_model_loaded:
        try:
@@ -91,6 +102,10 @@ def add_callback(callbacks, fun):
    callbacks.append(ScriptCallback(filename, fun))


def on_app_started(callback):
    add_callback(callbacks_app_started, callback)


def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument"""
+3 −0
Original line number Diff line number Diff line
@@ -23,6 +23,7 @@ import modules.sd_hijack
import modules.sd_models
import modules.shared as shared
import modules.txt2img
import modules.script_callbacks

import modules.ui
from modules import devices
@@ -135,6 +136,8 @@ def webui():
        if (launch_api):
            create_api(app)

        modules.script_callbacks.app_started_callback(demo, app)

        wait_on_server(demo)

        sd_samplers.set_samplers()