Unverified Commit 73a0b4bb authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #13944 from wfjsw/dag

implementing script metadata and DAG sorting mechanism
parents 411da7c2 bde439ef
Loading
Loading
Loading
Loading
+74 −10
Original line number Diff line number Diff line
import configparser
import functools
import os
import threading
import re

from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
@@ -23,8 +26,9 @@ class Extension:
    lock = threading.Lock()
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']

    def __init__(self, name, path, enabled=True, is_builtin=False):
    def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
        self.name = name
        self.canonical_name = canonical_name or name.lower()
        self.path = path
        self.enabled = enabled
        self.status = ''
@@ -37,6 +41,18 @@ class Extension:
        self.remote = None
        self.have_info_from_repo = False

    @functools.cached_property
    def metadata(self):
        """Parsed ``metadata.ini`` of this extension as a ConfigParser.

        Returns None when the file does not exist or cannot be parsed
        (a parse failure is reported, not raised). Cached after first access.
        """
        # Build the path once instead of repeating the join in the check and the read.
        metadata_path = os.path.join(self.path, "metadata.ini")
        if os.path.isfile(metadata_path):
            try:
                config = configparser.ConfigParser()
                config.read(metadata_path)
                return config
            except Exception:
                # best-effort: a broken metadata file must not prevent the
                # extension object itself from existing
                errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.",
                              exc_info=True)
        return None

    def to_dict(self):
        """Return the cacheable git-info fields as a plain dict keyed by field name."""
        snapshot = {}
        for field in self.cached_fields:
            snapshot[field] = getattr(self, field)
        return snapshot

@@ -56,6 +72,7 @@ class Extension:
                self.do_read_info_from_repo()

                return self.to_dict()

        try:
            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
            self.from_dict(d)
@@ -136,9 +153,6 @@ class Extension:
def list_extensions():
    """Scan the built-in and user extension directories, resolve inter-extension
    requirements declared in metadata.ini, and repopulate the global `extensions`
    list with Extension objects (disabled when a requirement is unmet)."""
    extensions.clear()

    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
        print("*** \"Disable all extensions\" option was set, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

    extension_dependency_map = {}

    # scan through extensions directory and load metadata
    for dirname in [extensions_builtin_dir, extensions_dir]:
        if not os.path.isdir(dirname):
            continue

        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
                continue

            canonical_name = extension_dirname
            requires = None

            if os.path.isfile(os.path.join(path, "metadata.ini")):
                try:
                    config = configparser.ConfigParser()
                    config.read(os.path.join(path, "metadata.ini"))
                    canonical_name = config.get("Extension", "Name", fallback=canonical_name)
                    requires = config.get("Extension", "Requires", fallback=None)
                except Exception:
                    errors.report(f"Error reading metadata.ini for extension {extension_dirname}. "
                                  f"Will load regardless.", exc_info=True)

            canonical_name = canonical_name.lower().strip()

            # check for duplicated canonical names
            if canonical_name in extension_dependency_map:
                errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
                              f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
                              f"The current loading extension will be discarded.", exc_info=False)
                continue

            # both "," and " " are accepted as separator
            requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []

            extension_dependency_map[canonical_name] = {
                "dirname": extension_dirname,
                "path": path,
                "requires": requires,
                # record builtin-ness HERE, while `dirname` still refers to the
                # directory being scanned; in the loop below `dirname` is rebound
                # to the extension's own directory name, so comparing it against
                # extensions_builtin_dir there would always be False
                "is_builtin": dirname == extensions_builtin_dir,
            }

    # check for requirements
    for canonical_name, extension_data in extension_dependency_map.items():
        dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
        requirement_met = True
        for req in requires:
            if req not in extension_dependency_map:
                errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break
            dep_dirname = extension_dependency_map[req]['dirname']
            if dep_dirname in shared.opts.disabled_extensions:
                errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break

        # pass the canonical name from metadata.ini through, so Extension does
        # not fall back to deriving it from the directory name
        extension = Extension(name=dirname, path=path,
                              enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
                              is_builtin=extension_data['is_builtin'],
                              canonical_name=canonical_name)
        extensions.append(extension)
+114 −14
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import os
import re
import sys
import inspect
from graphlib import TopologicalSorter, CycleError
from collections import namedtuple
from dataclasses import dataclass

@@ -314,15 +315,120 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi

def list_scripts(scriptdirname, extension, *, include_extensions=True):
    """Collect script files from the webui's own `scriptdirname` directory and
    (optionally) active extensions, ordered topologically by the
    Requires/Before/After metadata declared in each extension's metadata.ini.

    `extension` is the file extension to keep (e.g. ".py"); names in metadata
    are matched case-insensitively against canonical script/extension names.
    """
    scripts_list = []
    script_dependency_map = {}

    # build script dependency map

    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(root_script_basedir):
        for filename in sorted(os.listdir(root_script_basedir)):
            if not os.path.isfile(os.path.join(root_script_basedir, filename)):
                continue

            # webui's own scripts have no extension and no ordering constraints
            script_dependency_map[filename] = {
                "extension": None,
                "extension_dirname": None,
                "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
                "requires": [],
                "load_before": [],
                "load_after": [],
            }

    if include_extensions:
        for ext in extensions.active():
            extension_scripts_list = ext.list_files(scriptdirname, extension)
            for extension_script in extension_scripts_list:
                if not os.path.isfile(extension_script.path):
                    continue

                script_canonical_name = ext.canonical_name + "/" + extension_script.filename
                if ext.is_builtin:
                    script_canonical_name = "builtin/" + script_canonical_name
                relative_path = scriptdirname + "/" + extension_script.filename

                requires = ''
                load_before = ''
                load_after = ''

                if ext.metadata is not None:
                    requires = ext.metadata.get(relative_path, "Requires", fallback='')
                    load_before = ext.metadata.get(relative_path, "Before", fallback='')
                    load_after = ext.metadata.get(relative_path, "After", fallback='')

                    # propagate directory level metadata
                    requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='')
                    load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='')
                    load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='')

                # both "," and " " are accepted as separator
                requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
                load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else []
                load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else []

                script_dependency_map[script_canonical_name] = {
                    "extension": ext.canonical_name,
                    "extension_dirname": ext.name,
                    "script_file": extension_script,
                    "requires": requires,
                    "load_before": load_before,
                    "load_after": load_after,
                }

    # resolve dependencies

    loaded_extensions = {ext.canonical_name for ext in extensions.active()}

    for script_canonical_name, script_data in script_dependency_map.items():
        # "Before" is an inverse dependency: append this script's name to the
        # load_after list of whatever it must precede
        for load_before_script in script_data['load_before']:
            if load_before_script in script_dependency_map:
                # target is an individual script
                script_dependency_map[load_before_script]['load_after'].append(script_canonical_name)
            elif load_before_script in loaded_extensions:
                # target is an extension: apply to EVERY script of that
                # extension, not just the first one encountered
                for other_script_data in script_dependency_map.values():
                    if other_script_data['extension'] == load_before_script:
                        other_script_data['load_after'].append(script_canonical_name)

        # resolve extension names in load_after lists into that extension's scripts
        for load_after_script in list(script_data['load_after']):
            if load_after_script not in script_dependency_map and load_after_script in loaded_extensions:
                script_data['load_after'].remove(load_after_script)
                for other_canonical_name, other_script_data in script_dependency_map.items():
                    if other_script_data['extension'] == load_after_script:
                        script_data['load_after'].append(other_canonical_name)

    # build the DAG
    sorter = TopologicalSorter()
    requirement_unmet = set()
    for script_canonical_name, script_data in script_dependency_map.items():
        requirement_met = True
        for required_script in script_data['requires']:
            # a requirement may name an individual script or a whole extension
            if required_script not in script_dependency_map and required_script not in loaded_extensions:
                errors.report(f"Script \"{script_canonical_name}\" "
                              f"requires \"{required_script}\" to "
                              f"be loaded, but it is not. Skipping.",
                              exc_info=False)
                requirement_met = False
                break
        if not requirement_met:
            # remember it so it is also excluded from the final list even if
            # another script references it as a load_after predecessor
            requirement_unmet.add(script_canonical_name)
            continue

        sorter.add(script_canonical_name, *script_data['load_after'])

    # sort the scripts
    try:
        ordered_script = sorter.static_order()
    except CycleError:
        errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
        ordered_script = script_dependency_map.keys()

    for script_canonical_name in ordered_script:
        # static_order() can emit names that were only ever seen as
        # predecessors (unknown names, or scripts filtered out above); a plain
        # dict lookup here would raise KeyError
        script_data = script_dependency_map.get(script_canonical_name)
        if script_data is None or script_canonical_name in requirement_unmet:
            continue
        scripts_list.append(script_data['script_file'])

    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]

@@ -365,15 +471,9 @@ def load_scripts():
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

    def orderby(basedir):
        # 1st webui, 2nd extensions-builtin, 3rd extensions
        priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
        for key in priority:
            if basedir.startswith(key):
                return priority[key]
        return 9999

    for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
    # here the scripts_list is already ordered
    # processing_script is not considered though
    for scriptfile in scripts_list:
        try:
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path