Commit 71f4a4af authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Deduplicate webui.py initial-load/reload code

parent 0f28aee9
Loading
Loading
Loading
Loading
+0 −1
Original line number Diff line number Diff line
@@ -98,7 +98,6 @@ def setup_model():
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    list_models()
    enable_midas_autodownload()


+34 −50
Original line number Diff line number Diff line
@@ -15,6 +15,7 @@ from fastapi.middleware.gzip import GZipMiddleware
from packaging import version

import logging

logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())

from modules import paths, timer, import_hook, errors  # noqa: F401
@@ -231,9 +232,27 @@ def initialize():
    validate_tls_options()
    configure_sigint_handler()
    check_versions()
    modelloader.cleanup_models()
    configure_opts_onchange()

    modules.sd_models.setup_model()
    startup_timer.record("setup SD model")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    initialize_rest(reload_script_modules=False)


def initialize_rest(*, reload_script_modules=False):
    """
    Called both from initialize() and when reloading the webui.
    """
    sd_samplers.set_samplers()
    extensions.list_extensions()
    localization.list_localizations(cmd_opts.localizations_dir)
    startup_timer.record("list extensions")

    restore_config_state_file()
@@ -243,42 +262,40 @@ def initialize():
        modules.scripts.load_scripts()
        return

    modelloader.cleanup_models()
    modules.sd_models.setup_model()
    modules.sd_models.list_models()
    startup_timer.record("list SD models")

    codeformer.setup_model(cmd_opts.codeformer_models_path)
    startup_timer.record("setup codeformer")

    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")
    localization.list_localizations(cmd_opts.localizations_dir)

    modules.scripts.load_scripts()
    startup_timer.record("load scripts")

    if reload_script_modules:
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
        startup_timer.record("reload script modules")

    modelloader.load_upscalers()
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")

    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    # load model in parallel to other startup stuff
    # (when reloading, this does nothing)
    Thread(target=lambda: shared.sd_model).start()

    shared.reload_hypernetworks()
    startup_timer.record("reload hypernets")
    startup_timer.record("reload hypernetworks")

    ui_extra_networks.initialize()
    ui_extra_networks.register_default_pages()

    extra_networks.initialize()
    extra_networks.register_default_extra_networks()

    startup_timer.record("extra networks")

    startup_timer.record("initialize extra networks")


def setup_middleware(app):
@@ -423,45 +440,12 @@ def webui():
        print('Restarting UI...')
        shared.demo.close()
        time.sleep(0.5)
        modules.script_callbacks.app_reload_callback()

        startup_timer.reset()

        sd_samplers.set_samplers()

        modules.script_callbacks.app_reload_callback()
        startup_timer.record("app reload callback")
        modules.script_callbacks.script_unloaded_callback()
        extensions.list_extensions()
        startup_timer.record("list extensions")

        restore_config_state_file()

        localization.list_localizations(cmd_opts.localizations_dir)

        modules.scripts.reload_scripts()
        startup_timer.record("load scripts")

        modules.script_callbacks.model_loaded_callback(shared.sd_model)
        startup_timer.record("model loaded callback")

        modelloader.load_upscalers()
        startup_timer.record("load upscalers")

        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)
        startup_timer.record("reload script modules")

        modules.sd_models.list_models()
        startup_timer.record("list SD models")

        shared.reload_hypernetworks()
        startup_timer.record("reload hypernetworks")

        ui_extra_networks.initialize()
        ui_extra_networks.register_default_pages()

        extra_networks.initialize()
        extra_networks.register_default_extra_networks()
        startup_timer.record("initialize extra networks")
        startup_timer.record("scripts unloaded callback")
        initialize_rest(reload_script_modules=True)


if __name__ == "__main__":