Commit 65ed4421 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add callback for when the script is unloaded

parent c9bded39
Loading
Loading
Loading
Loading
+17 −1
Original line number Diff line number Diff line
@@ -71,6 +71,7 @@ callback_map = dict(
    callbacks_before_component=[],
    callbacks_after_component=[],
    callbacks_image_grid=[],
    callbacks_script_unloaded=[],
)


@@ -171,6 +172,14 @@ def image_grid_callback(params: ImageGridLoopParams):
            report_exception(c, 'image_grid')


def script_unloaded_callback():
    for c in reversed(callback_map['callbacks_script_unloaded']):
        try:
            c.callback()
        except Exception:
            report_exception(c, 'script_unloaded')


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -202,7 +211,7 @@ def on_app_started(callback):

def on_model_loaded(callback):
    """register a function to be called when the stable diffusion model is created; the model is
    passed as an argument"""
    passed as an argument; this function is also called when the script is reloaded. """
    add_callback(callback_map['callbacks_model_loaded'], callback)


@@ -279,3 +288,10 @@ def on_image_grid(callback):
       - params: ImageGridLoopParams - parameters to be used for grid creation. Can be modified.
    """
    add_callback(callback_map['callbacks_image_grid'], callback)


def on_script_unloaded(callback):
    """register a function to be called before the script is unloaded. Any hooks/hijacks/monkeying about that
    the script did should be reverted here"""

    add_callback(callback_map['callbacks_script_unloaded'], callback)
+2 −0
Original line number Diff line number Diff line
@@ -187,12 +187,14 @@ def webui():

        sd_samplers.set_samplers()

        modules.script_callbacks.script_unloaded_callback()
        extensions.list_extensions()

        localization.list_localizations(cmd_opts.localizations_dir)

        modelloader.forbid_loaded_nonbuiltin_upscalers()
        modules.scripts.reload_scripts()
        modules.script_callbacks.model_loaded_callback(shared.sd_model)
        modelloader.load_upscalers()

        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]: