Commit 2d5a5076 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

Make it so that upscalers are not repeated when restarting UI.

parent e9fb9bb0
Loading
Loading
Loading
Loading
+20 −0
Original line number Diff line number Diff line
@@ -123,6 +123,23 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
        pass


builtin_upscaler_classes = []
forbidden_upscaler_classes = set()


def list_builtin_upscalers():
    load_upscalers()

    builtin_upscaler_classes.clear()
    builtin_upscaler_classes.extend(Upscaler.__subclasses__())


def forbid_loaded_nonbuiltin_upscalers():
    for cls in Upscaler.__subclasses__():
        if cls not in builtin_upscaler_classes:
            forbidden_upscaler_classes.add(cls)


def load_upscalers():
    # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
    # so we'll try to import any _model.py files before looking in __subclasses__
@@ -139,6 +156,9 @@ def load_upscalers():
    datas = []
    commandline_options = vars(shared.cmd_opts)
    for cls in Upscaler.__subclasses__():
        if cls in forbidden_upscaler_classes:
            continue

        name = cls.__name__
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        scaler = cls(commandline_options.get(cmd_name, None))
+7 −7
Original line number Diff line number Diff line
import os
import sys
import threading
import time
import importlib
@@ -55,8 +56,8 @@ def initialize():
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    shared.face_restorers.append(modules.face_restoration.FaceRestoration())

    modelloader.list_builtin_upscalers()
    modules.scripts.load_scripts()

    modelloader.load_upscalers()

    modules.sd_vae.refresh_vae_list()
@@ -169,23 +170,22 @@ def webui():
        modules.script_callbacks.app_started_callback(shared.demo, app)

        wait_on_server(shared.demo)
        print('Restarting UI...')

        sd_samplers.set_samplers()

        print('Reloading extensions')
        extensions.list_extensions()

        localization.list_localizations(cmd_opts.localizations_dir)

        print('Reloading custom scripts')
        modelloader.forbid_loaded_nonbuiltin_upscalers()
        modules.scripts.reload_scripts()
        modelloader.load_upscalers()

        print('Reloading modules: modules.ui')
        importlib.reload(modules.ui)
        print('Refreshing Model List')
        for module in [module for name, module in sys.modules.items() if name.startswith("modules.ui")]:
            importlib.reload(module)

        modules.sd_models.list_models()
        print('Restarting Gradio')


if __name__ == "__main__":