Unverified Commit 73a0b4bb authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #13944 from wfjsw/dag

implementing script metadata and DAG sorting mechanism
parents 411da7c2 bde439ef
Loading
Loading
Loading
Loading
+74 −10
Original line number Original line Diff line number Diff line
import configparser
import functools
import os
import os
import threading
import threading
import re


from modules import shared, errors, cache, scripts
from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.gitpython_hack import Repo
@@ -23,8 +26,9 @@ class Extension:
    lock = threading.Lock()
    lock = threading.Lock()
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']


    def __init__(self, name, path, enabled=True, is_builtin=False):
    def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
        self.name = name
        self.name = name
        self.canonical_name = canonical_name or name.lower()
        self.path = path
        self.path = path
        self.enabled = enabled
        self.enabled = enabled
        self.status = ''
        self.status = ''
@@ -37,6 +41,18 @@ class Extension:
        self.remote = None
        self.remote = None
        self.have_info_from_repo = False
        self.have_info_from_repo = False


    @functools.cached_property
    def metadata(self):
        if os.path.isfile(os.path.join(self.path, "metadata.ini")):
            try:
                config = configparser.ConfigParser()
                config.read(os.path.join(self.path, "metadata.ini"))
                return config
            except Exception:
                errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.",
                              exc_info=True)
        return None

    def to_dict(self):
    def to_dict(self):
        return {x: getattr(self, x) for x in self.cached_fields}
        return {x: getattr(self, x) for x in self.cached_fields}


@@ -56,6 +72,7 @@ class Extension:
                self.do_read_info_from_repo()
                self.do_read_info_from_repo()


                return self.to_dict()
                return self.to_dict()

        try:
        try:
            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
            d = cache.cached_data_for_file('extensions-git', self.name, os.path.join(self.path, ".git"), read_from_repo)
            self.from_dict(d)
            self.from_dict(d)
@@ -136,9 +153,6 @@ class Extension:
def list_extensions():
def list_extensions():
    extensions.clear()
    extensions.clear()


    if not os.path.isdir(extensions_dir):
        return

    if shared.cmd_opts.disable_all_extensions:
    if shared.cmd_opts.disable_all_extensions:
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
        print("*** \"--disable-all-extensions\" arg was used, will not load any extensions ***")
    elif shared.opts.disable_all_extensions == "all":
    elif shared.opts.disable_all_extensions == "all":
@@ -148,18 +162,68 @@ def list_extensions():
    elif shared.opts.disable_all_extensions == "extra":
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")


    extension_paths = []
    extension_dependency_map = {}
    for dirname in [extensions_dir, extensions_builtin_dir]:

    # scan through extensions directory and load metadata
    for dirname in [extensions_builtin_dir, extensions_dir]:
        if not os.path.isdir(dirname):
        if not os.path.isdir(dirname):
            return
            continue


        for extension_dirname in sorted(os.listdir(dirname)):
        for extension_dirname in sorted(os.listdir(dirname)):
            path = os.path.join(dirname, extension_dirname)
            path = os.path.join(dirname, extension_dirname)
            if not os.path.isdir(path):
            if not os.path.isdir(path):
                continue
                continue


            extension_paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
            canonical_name = extension_dirname
            requires = None

            if os.path.isfile(os.path.join(path, "metadata.ini")):
                try:
                    config = configparser.ConfigParser()
                    config.read(os.path.join(path, "metadata.ini"))
                    canonical_name = config.get("Extension", "Name", fallback=canonical_name)
                    requires = config.get("Extension", "Requires", fallback=None)
                except Exception:
                    errors.report(f"Error reading metadata.ini for extension {extension_dirname}. "
                                  f"Will load regardless.", exc_info=True)

            canonical_name = canonical_name.lower().strip()

            # check for duplicated canonical names
            if canonical_name in extension_dependency_map:
                errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
                              f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
                              f"The current loading extension will be discarded.", exc_info=False)
                continue


    for dirname, path, is_builtin in extension_paths:
            # both "," and " " are accepted as separator
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
            requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []

            extension_dependency_map[canonical_name] = {
                "dirname": extension_dirname,
                "path": path,
                "requires": requires,
            }

    # check for requirements
    for (_, extension_data) in extension_dependency_map.items():
        dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
        requirement_met = True
        for req in requires:
            if req not in extension_dependency_map:
                errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break
            dep_dirname = extension_dependency_map[req]['dirname']
            if dep_dirname in shared.opts.disabled_extensions:
                errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break

        is_builtin = dirname == extensions_builtin_dir
        extension = Extension(name=dirname, path=path,
                              enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
                              is_builtin=is_builtin)
        extensions.append(extension)
        extensions.append(extension)
+114 −14
Original line number Original line Diff line number Diff line
@@ -2,6 +2,7 @@ import os
import re
import re
import sys
import sys
import inspect
import inspect
from graphlib import TopologicalSorter, CycleError
from collections import namedtuple
from collections import namedtuple
from dataclasses import dataclass
from dataclasses import dataclass


@@ -314,15 +315,120 @@ ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedi


def list_scripts(scriptdirname, extension, *, include_extensions=True):
def list_scripts(scriptdirname, extension, *, include_extensions=True):
    scripts_list = []
    scripts_list = []
    script_dependency_map = {}


    basedir = os.path.join(paths.script_path, scriptdirname)
    # build script dependency map
    if os.path.exists(basedir):

        for filename in sorted(os.listdir(basedir)):
    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
            scripts_list.append(ScriptFile(paths.script_path, filename, os.path.join(basedir, filename)))
    if os.path.exists(root_script_basedir):
        for filename in sorted(os.listdir(root_script_basedir)):
            if not os.path.isfile(os.path.join(root_script_basedir, filename)):
                continue

            script_dependency_map[filename] = {
                "extension": None,
                "extension_dirname": None,
                "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
                "requires": [],
                "load_before": [],
                "load_after": [],
            }


    if include_extensions:
    if include_extensions:
        for ext in extensions.active():
        for ext in extensions.active():
            scripts_list += ext.list_files(scriptdirname, extension)
            extension_scripts_list = ext.list_files(scriptdirname, extension)
            for extension_script in extension_scripts_list:
                if not os.path.isfile(extension_script.path):
                    continue

                script_canonical_name = ext.canonical_name + "/" + extension_script.filename
                if ext.is_builtin:
                    script_canonical_name = "builtin/" + script_canonical_name
                relative_path = scriptdirname + "/" + extension_script.filename

                requires = ''
                load_before = ''
                load_after = ''

                if ext.metadata is not None:
                    requires = ext.metadata.get(relative_path, "Requires", fallback='')
                    load_before = ext.metadata.get(relative_path, "Before", fallback='')
                    load_after = ext.metadata.get(relative_path, "After", fallback='')

                    # propagate directory level metadata
                    requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='')
                    load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='')
                    load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='')

                requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
                load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else []
                load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else []

                script_dependency_map[script_canonical_name] = {
                    "extension": ext.canonical_name,
                    "extension_dirname": ext.name,
                    "script_file": extension_script,
                    "requires": requires,
                    "load_before": load_before,
                    "load_after": load_after,
                }

    # resolve dependencies

    loaded_extensions = set()
    for ext in extensions.active():
        loaded_extensions.add(ext.canonical_name)

    for script_canonical_name, script_data in script_dependency_map.items():
        # load before requires inverse dependency
        # in this case, append the script name into the load_after list of the specified script
        for load_before_script in script_data['load_before']:
            # if this requires an individual script to be loaded before
            if load_before_script in script_dependency_map:
                script_dependency_map[load_before_script]['load_after'].append(script_canonical_name)
            elif load_before_script in loaded_extensions:
                for _, script_data2 in script_dependency_map.items():
                    if script_data2['extension'] == load_before_script:
                        script_data2['load_after'].append(script_canonical_name)
                        break

        # resolve extension name in load_after lists
        for load_after_script in list(script_data['load_after']):
            if load_after_script not in script_dependency_map and load_after_script in loaded_extensions:
                script_data['load_after'].remove(load_after_script)
                for script_canonical_name2, script_data2 in script_dependency_map.items():
                    if script_data2['extension'] == load_after_script:
                        script_data['load_after'].append(script_canonical_name2)
                        break

    # build the DAG
    sorter = TopologicalSorter()
    for script_canonical_name, script_data in script_dependency_map.items():
        requirement_met = True
        for required_script in script_data['requires']:
            # if this requires an individual script to be loaded
            if required_script not in script_dependency_map and required_script not in loaded_extensions:
                errors.report(f"Script \"{script_canonical_name}\" "
                              f"requires \"{required_script}\" to "
                              f"be loaded, but it is not. Skipping.",
                              exc_info=False)
                requirement_met = False
                break
        if not requirement_met:
            continue

        sorter.add(script_canonical_name, *script_data['load_after'])

    # sort the scripts
    try:
        ordered_script = sorter.static_order()
    except CycleError:
        errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
        ordered_script = script_dependency_map.keys()

    for script_canonical_name in ordered_script:
        script_data = script_dependency_map[script_canonical_name]
        scripts_list.append(script_data['script_file'])


    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]


@@ -365,15 +471,9 @@ def load_scripts():
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))


    def orderby(basedir):
    # here the scripts_list is already ordered
        # 1st webui, 2nd extensions-builtin, 3rd extensions
    # processing_script is not considered though
        priority = {os.path.join(paths.script_path, "extensions-builtin"):1, paths.script_path:0}
    for scriptfile in scripts_list:
        for key in priority:
            if basedir.startswith(key):
                return priority[key]
        return 9999

    for scriptfile in sorted(scripts_list, key=lambda x: [orderby(x.basedir), x]):
        try:
        try:
            if scriptfile.basedir != paths.script_path:
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
                sys.path = [scriptfile.basedir] + sys.path