Commit 1463cea9 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

Merge branch 'dag' into dev

parents 73a0b4bb 9b471436
Loading
Loading
Loading
Loading
+70 −62
Original line number Diff line number Diff line
from __future__ import annotations

import configparser
import functools
import os
import threading
import re
@@ -8,7 +9,6 @@ from modules import shared, errors, cache, scripts
from modules.gitpython_hack import Repo
from modules.paths_internal import extensions_dir, extensions_builtin_dir, script_path  # noqa: F401

extensions = []

os.makedirs(extensions_dir, exist_ok=True)

@@ -22,13 +22,56 @@ def active():
        return [x for x in extensions if x.enabled]


class ExtensionMetadata:
    filename = "metadata.ini"
    config: configparser.ConfigParser
    canonical_name: str
    requires: list

    def __init__(self, path, canonical_name):
        self.config = configparser.ConfigParser()

        filepath = os.path.join(path, self.filename)
        if os.path.isfile(filepath):
            try:
                self.config.read(filepath)
            except Exception:
                errors.report(f"Error reading {self.filename} for extension {canonical_name}.", exc_info=True)

        self.canonical_name = self.config.get("Extension", "Name", fallback=canonical_name)
        self.canonical_name = canonical_name.lower().strip()

        self.requires = self.get_script_requirements("Requires", "Extension")

    def get_script_requirements(self, field, section, extra_section=None):
        """reads a list of requirements from the config; field is the name of the field in the ini file,
        like Requires or Before, and section is the name of the [section] in the ini file; additionally,
        reads more requirements from [extra_section] if specified."""

        x = self.config.get(section, field, fallback='')

        if extra_section:
            x = x + ', ' + self.config.get(extra_section, field, fallback='')

        return self.parse_list(x.lower())

    def parse_list(self, text):
        """converts a line from config ("ext1 ext2, ext3  ") into a python list (["ext1", "ext2", "ext3"])"""

        if not text:
            return []

        # both "," and " " are accepted as separator
        return [x for x in re.split(r"[,\s]+", text.strip()) if x]


class Extension:
    lock = threading.Lock()
    cached_fields = ['remote', 'commit_date', 'branch', 'commit_hash', 'version']
    metadata: ExtensionMetadata

    def __init__(self, name, path, enabled=True, is_builtin=False, canonical_name=None):
    def __init__(self, name, path, enabled=True, is_builtin=False, metadata=None):
        self.name = name
        self.canonical_name = canonical_name or name.lower()
        self.path = path
        self.enabled = enabled
        self.status = ''
@@ -40,18 +83,8 @@ class Extension:
        self.branch = None
        self.remote = None
        self.have_info_from_repo = False

    @functools.cached_property
    def metadata(self):
        if os.path.isfile(os.path.join(self.path, "metadata.ini")):
            try:
                config = configparser.ConfigParser()
                config.read(os.path.join(self.path, "metadata.ini"))
                return config
            except Exception:
                errors.report(f"Error reading metadata.ini for extension {self.canonical_name}.",
                              exc_info=True)
        return None
        self.metadata = metadata if metadata else ExtensionMetadata(self.path, name.lower())
        self.canonical_name = metadata.canonical_name

    def to_dict(self):
        return {x: getattr(self, x) for x in self.cached_fields}
@@ -162,7 +195,7 @@ def list_extensions():
    elif shared.opts.disable_all_extensions == "extra":
        print("*** \"Disable all extensions\" option was set, will only load built-in extensions ***")

    extension_dependency_map = {}
    loaded_extensions = {}

    # scan through extensions directory and load metadata
    for dirname in [extensions_builtin_dir, extensions_dir]:
@@ -175,55 +208,30 @@ def list_extensions():
                continue

            canonical_name = extension_dirname
            requires = None

            if os.path.isfile(os.path.join(path, "metadata.ini")):
                try:
                    config = configparser.ConfigParser()
                    config.read(os.path.join(path, "metadata.ini"))
                    canonical_name = config.get("Extension", "Name", fallback=canonical_name)
                    requires = config.get("Extension", "Requires", fallback=None)
                except Exception:
                    errors.report(f"Error reading metadata.ini for extension {extension_dirname}. "
                                  f"Will load regardless.", exc_info=True)

            canonical_name = canonical_name.lower().strip()
            metadata = ExtensionMetadata(path, canonical_name)

            # check for duplicated canonical names
            if canonical_name in extension_dependency_map:
                errors.report(f"Duplicate canonical name \"{canonical_name}\" found in extensions "
                              f"\"{extension_dirname}\" and \"{extension_dependency_map[canonical_name]['dirname']}\". "
                              f"The current loading extension will be discarded.", exc_info=False)
            already_loaded_extension = loaded_extensions.get(metadata.canonical_name)
            if already_loaded_extension is not None:
                errors.report(f'Duplicate canonical name "{canonical_name}" found in extensions "{extension_dirname}" and "{already_loaded_extension.name}". Former will be discarded.', exc_info=False)
                continue

            # both "," and " " are accepted as separator
            requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []

            extension_dependency_map[canonical_name] = {
                "dirname": extension_dirname,
                "path": path,
                "requires": requires,
            }
            is_builtin = dirname == extensions_builtin_dir
            extension = Extension(name=extension_dirname, path=path, enabled=extension_dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin, metadata=metadata)
            extensions.append(extension)
            loaded_extensions[canonical_name] = extension

    # check for requirements
    for (_, extension_data) in extension_dependency_map.items():
        dirname, path, requires = extension_data['dirname'], extension_data['path'], extension_data['requires']
        requirement_met = True
        for req in requires:
            if req not in extension_dependency_map:
                errors.report(f"Extension \"{dirname}\" requires \"{req}\" which is not installed. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break
            dep_dirname = extension_dependency_map[req]['dirname']
            if dep_dirname in shared.opts.disabled_extensions:
                errors.report(f"Extension \"{dirname}\" requires \"{dep_dirname}\" which is disabled. "
                              f"The current loading extension will be discarded.", exc_info=False)
                requirement_met = False
                break
    for extension in extensions:
        for req in extension.metadata.requires:
            required_extension = loaded_extensions.get(req)
            if required_extension is None:
                errors.report(f'Extension "{extension.name}" requires "{req}" which is not installed.', exc_info=False)
                continue

        is_builtin = dirname == extensions_builtin_dir
        extension = Extension(name=dirname, path=path,
                              enabled=dirname not in shared.opts.disabled_extensions and requirement_met,
                              is_builtin=is_builtin)
        extensions.append(extension)
            if not extension.enabled:
                errors.report(f'Extension "{extension.name}" requires "{required_extension.name}" which is disabled.', exc_info=False)
                continue


extensions: list[Extension] = []
+78 −91
Original line number Diff line number Diff line
@@ -2,7 +2,6 @@ import os
import re
import sys
import inspect
from graphlib import TopologicalSorter, CycleError
from collections import namedtuple
from dataclasses import dataclass

@@ -312,27 +311,57 @@ scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])

def topological_sort(dependencies):
    """Accepts a dictionary mapping name to its dependencies, returns a list of names ordered according to dependencies.
    Ignores errors relating to missing dependeencies or circular dependencies
    """

    visited = {}
    result = []

    def inner(name):
        visited[name] = True

        for dep in dependencies.get(name, []):
            if dep in dependencies and dep not in visited:
                inner(dep)

        result.append(name)

    for depname in dependencies:
        if depname not in visited:
            inner(depname)

    return result


@dataclass
class ScriptWithDependencies:
    script_canonical_name: str
    file: ScriptFile
    requires: list
    load_before: list
    load_after: list


def list_scripts(scriptdirname, extension, *, include_extensions=True):
    scripts_list = []
    script_dependency_map = {}
    scripts = {}

    # build script dependency map
    loaded_extensions = {ext.canonical_name: ext for ext in extensions.active()}
    loaded_extensions_scripts = {ext.canonical_name: [] for ext in extensions.active()}

    # build script dependency map
    root_script_basedir = os.path.join(paths.script_path, scriptdirname)
    if os.path.exists(root_script_basedir):
        for filename in sorted(os.listdir(root_script_basedir)):
            if not os.path.isfile(os.path.join(root_script_basedir, filename)):
                continue

            script_dependency_map[filename] = {
                "extension": None,
                "extension_dirname": None,
                "script_file": ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename)),
                "requires": [],
                "load_before": [],
                "load_after": [],
            }
            if os.path.splitext(filename)[1].lower() != extension:
                continue

            script_file = ScriptFile(paths.script_path, filename, os.path.join(root_script_basedir, filename))
            scripts[filename] = ScriptWithDependencies(filename, script_file, [], [], [])

    if include_extensions:
        for ext in extensions.active():
@@ -341,96 +370,54 @@ def list_scripts(scriptdirname, extension, *, include_extensions=True):
                if not os.path.isfile(extension_script.path):
                    continue

                script_canonical_name = ext.canonical_name + "/" + extension_script.filename
                if ext.is_builtin:
                    script_canonical_name = "builtin/" + script_canonical_name
                script_canonical_name = ("builtin/" if ext.is_builtin else "") + ext.canonical_name + "/" + extension_script.filename
                relative_path = scriptdirname + "/" + extension_script.filename

                requires = ''
                load_before = ''
                load_after = ''

                if ext.metadata is not None:
                    requires = ext.metadata.get(relative_path, "Requires", fallback='')
                    load_before = ext.metadata.get(relative_path, "Before", fallback='')
                    load_after = ext.metadata.get(relative_path, "After", fallback='')

                    # propagate directory level metadata
                    requires = requires + ',' + ext.metadata.get(scriptdirname, "Requires", fallback='')
                    load_before = load_before + ',' + ext.metadata.get(scriptdirname, "Before", fallback='')
                    load_after = load_after + ',' + ext.metadata.get(scriptdirname, "After", fallback='')

                requires = list(filter(None, re.split(r"[,\s]+", requires.lower()))) if requires else []
                load_after = list(filter(None, re.split(r"[,\s]+", load_after.lower()))) if load_after else []
                load_before = list(filter(None, re.split(r"[,\s]+", load_before.lower()))) if load_before else []

                script_dependency_map[script_canonical_name] = {
                    "extension": ext.canonical_name,
                    "extension_dirname": ext.name,
                    "script_file": extension_script,
                    "requires": requires,
                    "load_before": load_before,
                    "load_after": load_after,
                }
                script = ScriptWithDependencies(
                    script_canonical_name=script_canonical_name,
                    file=extension_script,
                    requires=ext.metadata.get_script_requirements("Requires", relative_path, scriptdirname),
                    load_before=ext.metadata.get_script_requirements("Before", relative_path, scriptdirname),
                    load_after=ext.metadata.get_script_requirements("After", relative_path, scriptdirname),
                )

    # resolve dependencies
                scripts[script_canonical_name] = script
                loaded_extensions_scripts[ext.canonical_name].append(script)

    loaded_extensions = set()
    for ext in extensions.active():
        loaded_extensions.add(ext.canonical_name)

    for script_canonical_name, script_data in script_dependency_map.items():
    for script_canonical_name, script in scripts.items():
        # load before requires inverse dependency
        # in this case, append the script name into the load_after list of the specified script
        for load_before_script in script_data['load_before']:
        for load_before in script.load_before:
            # if this requires an individual script to be loaded before
            if load_before_script in script_dependency_map:
                script_dependency_map[load_before_script]['load_after'].append(script_canonical_name)
            elif load_before_script in loaded_extensions:
                for _, script_data2 in script_dependency_map.items():
                    if script_data2['extension'] == load_before_script:
                        script_data2['load_after'].append(script_canonical_name)
                        break

        # resolve extension name in load_after lists
        for load_after_script in list(script_data['load_after']):
            if load_after_script not in script_dependency_map and load_after_script in loaded_extensions:
                script_data['load_after'].remove(load_after_script)
                for script_canonical_name2, script_data2 in script_dependency_map.items():
                    if script_data2['extension'] == load_after_script:
                        script_data['load_after'].append(script_canonical_name2)
                        break

    # build the DAG
    sorter = TopologicalSorter()
    for script_canonical_name, script_data in script_dependency_map.items():
        requirement_met = True
        for required_script in script_data['requires']:
            # if this requires an individual script to be loaded
            if required_script not in script_dependency_map and required_script not in loaded_extensions:
                errors.report(f"Script \"{script_canonical_name}\" "
                              f"requires \"{required_script}\" to "
                              f"be loaded, but it is not. Skipping.",
                              exc_info=False)
                requirement_met = False
                break
        if not requirement_met:
            continue
            other_script = scripts.get(load_before)
            if other_script:
                other_script.load_after.append(script_canonical_name)

        sorter.add(script_canonical_name, *script_data['load_after'])
            # if this requires an extension
            other_extension_scripts = loaded_extensions_scripts.get(load_before)
            if other_extension_scripts:
                for other_script in other_extension_scripts:
                    other_script.load_after.append(script_canonical_name)

    # sort the scripts
    try:
        ordered_script = sorter.static_order()
    except CycleError:
        errors.report("Cycle detected in script dependencies. Scripts will load in ascending order.", exc_info=True)
        ordered_script = script_dependency_map.keys()
        # if After mentions an extension, remove it and instead add all of its scripts
        for load_after in list(script.load_after):
            if load_after not in scripts and load_after in loaded_extensions_scripts:
                script.load_after.remove(load_after)

                for other_script in loaded_extensions_scripts.get(load_after, []):
                    script.load_after.append(other_script.script_canonical_name)

    dependencies = {}

    for script_canonical_name, script in scripts.items():
        for required_script in script.requires:
            if required_script not in scripts and required_script not in loaded_extensions:
                errors.report(f'Script "{script_canonical_name}" requires "{required_script}" to be loaded, but it is not.', exc_info=False)

    for script_canonical_name in ordered_script:
        script_data = script_dependency_map[script_canonical_name]
        scripts_list.append(script_data['script_file'])
        dependencies[script_canonical_name] = script.load_after

    scripts_list = [x for x in scripts_list if os.path.splitext(x.path)[1].lower() == extension and os.path.isfile(x.path)]
    ordered_scripts = topological_sort(dependencies)
    scripts_list = [scripts[script_canonical_name].file for script_canonical_name in ordered_scripts]

    return scripts_list