Commit 0fc7dc1c authored by wfjsw's avatar wfjsw
Browse files

implementing script metadata and DAG sorting mechanism

parent 5e80d9ee
Loading
Loading
Loading
Loading
+71 −9
Original line number Original line Diff line number Diff line
import configparser
import functools
import os
import os
import threading
import threading


@@ -23,8 +25,9 @@ class Extension:
    lock = threading.Lock()
    lock = threading.Lock()
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']


    def __init__(self, name, path, enabled=True, is_builtin=False):
    def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
        self.name = name
        self.name = name
        self.canonical_name = canonical_name or name.lower()
        self.path = path
        self.path = path
        self.enabled = enabled
        self.enabled = enabled
        self.status = ''
        self.status = ''
@@ -37,6 +40,17 @@ class Extension:
        self.remote = None
        self.remote = None
        self.have_info_from_repo = False
        self.have_info_from_repo = False


    @functools.cached_property
    def metadata(self):
        if os.path.isfile(os.path.join(self.path, "sd_webui_metadata.ini")):
            try:
                config = configparser.ConfigParser()
                config.read(os.path.join(self.path, "sd_webui_metadata.ini"))
                return config
            except Exception:
                errors.report(f"Error reading sd_webui_metadata.ini for extension {self.canonical_name}.", exc_info=True)
        return None

    def to_dict(self):
    def to_dict(self):
        return {x: getattr(self, x) for x in self.cached_fields}
        return {x: getattr(self, x) for x in self.cached_fields}


@@ -136,9 +150,6 @@ class Extension:
def list_extensions():
def list_extensions():
    extensions.clear()
    extensions.clear()


    if not os.path.isdir(extensions_dir):
        return

    if shared.cmd_opts.disable_all_extensions:
    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
    elif shared.opts.disable_all_extensions == "all":
@@ -148,18 +159,69 @@ def list_extensions():
    elif shared.opts.disable_all_extensions == "extra":
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")


    extension_paths = []
    extension_dependency_map = {}

    # scan through extensions directory and load metadata
    for dirname in [extensions_dir, extensions_builtin_dir]:
    for dirname in [extensions_dir, extensions_builtin_dir]:
        if not os.path.isdir(dirname):
        if not os.path.isdir(dirname):
            return
            continue


        for extension_dirname in sorted(os.listdir(dirname)):
        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
            if not os.path.isdir(path):
                continue
                continue


            extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
            canonical_name = extension_dirname
            requires = None

            if os.path.isfile(os.path.join(path, "sd_webui_metadata.ini")):
                try:
                    config = configparser.ConfigParser()
                    config.read(os.path.join(path, "sd_webui_metadata.ini"))
                    canonical_name = config.get("Extension", "Name", fallback=canonical_name)
                    requires = config.get("Extension", "Requires", fallback=None)
                    continue
                except Exception:
                    errors.report(f"Error reading sd_webui_metadata.ini for extension {extension_dirname}. "
                                  f"Will load regardless.", exc_info=True)

            canonical_name = canonical_name.lower().strip()

            # check for duplicated canonical names
            if canonical_name in extension_dependency_map:
                errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
                              f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
                              f"The current loading extension will be discarded.", exc_info=False)
                continue


    for dirname, path, is_builtin in extension_paths:
            # we want to wash the data to lowercase and remove whitespaces just in case
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
            requires = [x.strip() for x in requires.lower().split(',')] if requires else []

            extension_dependency_map[canonical_name] = {
                "dirname": extension_dirname,
                "path": path,
                "requires": requires,
            }

    # check for requirements
    for (_, extension_data) in extension_dependency_map.items():
        dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
        requirement_met = True
        for req in requires:
            if req not in extension_dependency_map:
                errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break
            dep_dirname = extension_dependency_map[req]['dirname']
            if dep_dirname in shared.opts.disabled_extensions:
                errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break

        is_builtin = dirname == extensions_builtin_dir
        extension = Extension(name=dirname, path=path,
                              enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
                              is_builtin=is_builtin)
        extensions.append(extension)
        extensions.append(extension)
+126 −15
Original line number Original line Diff line number Diff line
@@ -2,6 +2,7 @@ import os
import re
import re
import sys
import sys
import inspect
import inspect
from graphlib import TopologicalSorter, CycleError
from collections import namedtuple
from collections import namedtuple
from dataclasses import dataclass
from dataclasses import dataclass


@@ -314,15 +315,131 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi


def list_scripts(scriptdirname, extension, *, include_extensions=True):
def list_scripts(scriptdirname, extension, *, include_extensions=True):
    scripts_list = []
    scripts_list = []

    script_dependency_map = {}
    basedir = os.path.join(paths.script_path, scriptdirname)

    if os.path.exists(basedir):
    # build script dependency map
        for filename in sorted(os.listdir(basedir)):

            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(root_script_basedir):
        for filename in sorted(os.listdir(root_script_basedir)):
            script_dependency_map[filename] = {
                "extension": None,
                "extension_dirname": None,
                "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
                "requires": [],
                "load_before": [],
                "load_after": [],
            }


    if include_extensions:
    if include_extensions:
        for ext in extensions.active():
        for ext in extensions.active():
            scripts_list += ext.list_files(scriptdirname, extension)
            extension_scripts_list = ext.list_files(scriptdirname, extension)
            for extension_script in extension_scripts_list:
                # this is built on the assumption that script name is unique.
                # I think bad thing is gonna happen if name collide in the current implementation anyway, but we
                # will need to refactor here if this assumption is broken later on.
                if extension_script.filename in script_dependency_map:
                    errors.report(f"Duplicate script name \"{extension_script.filename}\" found in extensions "
                                  f"\"{ext.name}\" and \"{script_dependency_map[extension_script.filename]['extension_dirname'] or 'builtin'}\". "
                                  f"The current loading file will be discarded.", exc_info=False)
                    continue

                relative_path = scriptdirname + "/" + extension_script.filename

                requires = None
                load_before = None
                load_after = None

                if ext.metadata is not None:
                    requires = ext.metadata.get(relative_path, "Requires", fallback=None)
                    load_before = ext.metadata.get(relative_path, "Before", fallback=None)
                    load_after = ext.metadata.get(relative_path, "After", fallback=None)

                requires = [x.strip() for x in requires.split(',')] if requires else []
                load_after = [x.strip() for x in load_after.split(',')] if load_after else []
                load_before = [x.strip() for x in load_before.split(',')] if load_before else []

                script_dependency_map[extension_script.filename] = {
                    "extension": ext.canonical_name,
                    "extension_dirname": ext.name,
                    "script_file": extension_script,
                    "requires": requires,
                    "load_before": load_before,
                    "load_after": load_after,
                }

    # resolve dependencies

    loaded_extensions = set()
    for _, script_data in script_dependency_map.items():
        if script_data['extension'] is not None:
            loaded_extensions.add(script_data['extension'])

    for script_filename, script_data in script_dependency_map.items():
        # load before requires inverse dependency
        # in this case, append the script name into the load_after list of the specified script
        for load_before_script in script_data['load_before']:
            if load_before_script.startswith('ext:'):
                # if this requires an extension to be loaded before
                required_extension = load_before_script[4:]
                for _, script_data2 in script_dependency_map.items():
                    if script_data2['extension'] == required_extension:
                        script_data2['load_after'].append(script_filename)
                        break
            else:
                # if this requires an individual script to be loaded before
                if load_before_script in script_dependency_map:
                    script_dependency_map[load_before_script]['load_after'].append(script_filename)

        # resolve extension name in load_after lists
        for load_after_script in script_data['load_after']:
            if load_after_script.startswith('ext:'):
                # if this requires an extension to be loaded after
                required_extension = load_after_script[4:]
                for script_file_name2, script_data2 in script_dependency_map.items():
                    if script_data2['extension'] == required_extension:
                        script_data['load_after'].append(script_file_name2)

        # remove all extension names in load_after lists
        script_data['load_after'] = [x for x in script_data['load_after'] if not x.startswith('ext:')]

    # build the DAG
    sorter = TopologicalSorter()
    for script_filename, script_data in script_dependency_map.items():
        requirement_met = True
        for required_script in script_data['requires']:
            if required_script.startswith('ext:'):
                # if this requires an extension to be installed
                required_extension = required_script[4:]
                if required_extension not in loaded_extensions:
                    errors.report(f"Script \"{script_filename}\" requires extension \"{required_extension}\" to "
                                  f"be loaded, but it is not. Skipping.",
                                  exc_info=False)
                    requirement_met = False
                    break
            else:
                # if this requires an individual script to be loaded
                if required_script not in script_dependency_map:
                    errors.report(f"Script \"{script_filename}\" requires script \"{required_script}\" to "
                                  f"be loaded, but it is not. Skipping.",
                                  exc_info=False)
                    requirement_met = False
                    break
        if not requirement_met:
            continue

        sorter.add(script_filename, *script_data['load_after'])

    # sort the scripts
    try:
        ordered_script = sorter.static_order()
    except CycleError:
        errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
        ordered_script = script_dependency_map.keys()

    for script_filename in ordered_script:
        script_data = script_dependency_map[script_filename]
        scripts_list.append(script_data['script_file'])


    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]


@@ -365,15 +482,9 @@ def load_scripts():
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))


    def orderby(basedir):
    # here the scripts_list is already ordered
        # 1st webui, 2nd extensions-builtin, 3rd extensions
    # processing_script is not considered though
        priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
    for scriptfile in scripts_list:
        for key in priority:
            if basedir.startswith(key):
                return priority[key]
        return 9999

    for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
        try:
        try:
            if scriptfile.basedir != paths.script_path:
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
                sys.path = [scriptfile.basedir] + sys.path