Commit a1a37633 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make existing script loading and new preload code use same code for loading modules

limit extension preload scripts to just one file named preload.py
parent e5690d0b
Loading
Loading
Loading
Loading
+0 −21
Original line number Diff line number Diff line
import os
import sys
import traceback
from importlib.machinery import SourceFileLoader

import git

@@ -85,23 +84,3 @@ def list_extensions():
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
        extensions.append(extension)

def preload_extensions(parser):
    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        path = os.path.join(extensions_dir, dirname)
        if not os.path.isdir(path):
            continue
        for file in os.listdir(path):
            if "preload.py" in file:
                full_file = os.path.join(path, file)
                print(f"Got preload file: {full_file}")

                try:
                    ext = SourceFileLoader("preload", full_file).load_module()
                    parser = ext.preload(parser)
                except Exception as e:
                    print(f"Exception preloading script: {e}")
    return parser
 No newline at end of file
+34 −0
Original line number Diff line number Diff line
import os
import sys
import traceback
from types import ModuleType


def load_module(path):
    with open(path, "r", encoding="utf8") as file:
        text = file.read()

    compiled = compile(text, path, 'exec')
    module = ModuleType(os.path.basename(path))
    exec(compiled, module.__dict__)

    return module


def preload_extensions(extensions_dir, parser):
    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue

        try:
            module = load_module(preload_script)
            if hasattr(module, 'preload'):
                module.preload(parser)

        except Exception:
            print(f"Error running preload() for {preload_script}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
+17 −29
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from collections import namedtuple
import gradio as gr

from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions
from modules import shared, paths, script_callbacks, extensions, script_loading

AlwaysVisible = object()

@@ -161,13 +161,7 @@ def load_scripts():
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            with open(scriptfile.path, "r", encoding="utf8") as file:
                text = file.read()

            from types import ModuleType
            compiled = compile(text, scriptfile.path, 'exec')
            module = ModuleType(scriptfile.filename)
            exec(compiled, module.__dict__)
            module = script_loading.load_module(scriptfile.path)

            for key, script_class in module.__dict__.items():
                if type(script_class) == type and issubclass(script_class, Script):
@@ -328,19 +322,13 @@ class ScriptRunner:

    def reload_sources(self, cache):
        for si, script in list(enumerate(self.scripts)):
            with open(script.filename, "r", encoding="utf8") as file:
            args_from = script.args_from
            args_to = script.args_to
            filename = script.filename
                text = file.read()

                from types import ModuleType

            module = cache.get(filename, None)
            if module is None:
                    compiled = compile(text, filename, 'exec')
                    module = ModuleType(script.filename)
                    exec(compiled, module.__dict__)
                module = script_loading.load_module(script.filename)
                cache[filename] = module

            for key, script_class in module.__dict__.items():
+2 −3
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ import datetime
import json
import os
import sys
from collections import OrderedDict
import time

import gradio as gr
@@ -15,7 +14,7 @@ import modules.memmon
import modules.sd_models
import modules.styles
import modules.devices as devices
from modules import sd_samplers, sd_models, localization, sd_vae, extensions
from modules import sd_samplers, sd_models, localization, sd_vae, extensions, script_loading
from modules.hypernetworks import hypernetwork
from modules.paths import models_path, script_path, sd_path

@@ -91,7 +90,7 @@ parser.add_argument("--tls-keyfile", type=str, help="Partially enables TLS, requ
parser.add_argument("--tls-certfile", type=str, help="Partially enables TLS, requires --tls-keyfile to fully function", default=None)
parser.add_argument("--server-name", type=str, help="Sets hostname of server", default=None)

extensions.preload_extensions(parser)
script_loading.preload_extensions(extensions.extensions_dir, parser)

cmd_opts = parser.parse_args()