Commit 45dca056 authored by Muhammad Rizqi Nur's avatar Muhammad Rizqi Nur
Browse files

Merge branch 'a1111' into vae-fix-none

parents 028b67b6 d9fd4525
Loading
Loading
Loading
Loading
+15 −11
Original line number Diff line number Diff line
@@ -105,24 +105,28 @@ def version_check(commit):
        print("version check failed", e)


def run_extensions_installers():
    if not os.path.isdir(dir_extensions):
        return

    for dirname_extension in os.listdir(dir_extensions):
        path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
def run_extension_installer(extension_dir):
    path_installer = os.path.join(extension_dir, "install.py")
    if not os.path.isfile(path_installer):
            continue
        return

    try:
        env = os.environ.copy()
        env['PYTHONPATH'] = os.path.abspath(".")

            print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}", custom_env=env))
        print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {extension_dir}", custom_env=env))
    except Exception as e:
        print(e, file=sys.stderr)


def run_extensions_installers():
    if not os.path.isdir(dir_extensions):
        return

    for dirname_extension in os.listdir(dir_extensions):
        run_extension_installer(os.path.join(dir_extensions, dirname_extension))


def prepare_enviroment():
    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
+0 −21
Original line number Diff line number Diff line
import os
import sys
import traceback
from importlib.machinery import SourceFileLoader

import git

@@ -85,23 +84,3 @@ def list_extensions():
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions)
        extensions.append(extension)

def preload_extensions(parser):
    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        path = os.path.join(extensions_dir, dirname)
        if not os.path.isdir(path):
            continue
        for file in os.listdir(path):
            if "preload.py" in file:
                full_file = os.path.join(path, file)
                print(f"Got preload file: {full_file}")

                try:
                    ext = SourceFileLoader("preload", full_file).load_module()
                    parser = ext.preload(parser)
                except Exception as e:
                    print(f"Exception preloading script: {e}")
    return parser
 No newline at end of file
+34 −0
Original line number Diff line number Diff line
import os
import sys
import traceback
from types import ModuleType


def load_module(path):
    with open(path, "r", encoding="utf8") as file:
        text = file.read()

    compiled = compile(text, path, 'exec')
    module = ModuleType(os.path.basename(path))
    exec(compiled, module.__dict__)

    return module


def preload_extensions(extensions_dir, parser):
    if not os.path.isdir(extensions_dir):
        return

    for dirname in sorted(os.listdir(extensions_dir)):
        preload_script = os.path.join(extensions_dir, dirname, "preload.py")
        if not os.path.isfile(preload_script):
            continue

        try:
            module = load_module(preload_script)
            if hasattr(module, 'preload'):
                module.preload(parser)

        except Exception:
            print(f"Error running preload() for {preload_script}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
+17 −29
Original line number Diff line number Diff line
@@ -6,7 +6,7 @@ from collections import namedtuple
import gradio as gr

from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions
from modules import shared, paths, script_callbacks, extensions, script_loading

AlwaysVisible = object()

@@ -161,13 +161,7 @@ def load_scripts():
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            with open(scriptfile.path, "r", encoding="utf8") as file:
                text = file.read()

            from types import ModuleType
            compiled = compile(text, scriptfile.path, 'exec')
            module = ModuleType(scriptfile.filename)
            exec(compiled, module.__dict__)
            module = script_loading.load_module(scriptfile.path)

            for key, script_class in module.__dict__.items():
                if type(script_class) == type and issubclass(script_class, Script):
@@ -328,19 +322,13 @@ class ScriptRunner:

    def reload_sources(self, cache):
        for si, script in list(enumerate(self.scripts)):
            with open(script.filename, "r", encoding="utf8") as file:
            args_from = script.args_from
            args_to = script.args_to
            filename = script.filename
                text = file.read()

                from types import ModuleType

            module = cache.get(filename, None)
            if module is None:
                    compiled = compile(text, filename, 'exec')
                    module = ModuleType(script.filename)
                    exec(compiled, module.__dict__)
                module = script_loading.load_module(script.filename)
                cache[filename] = module

            for key, script_class in module.__dict__.items():
+16 −9
Original line number Diff line number Diff line
@@ -87,7 +87,19 @@ def refresh_vae_list(vae_path=vae_path, model_path=model_path):
    return vae_list


def resolve_vae(checkpoint_file, vae_file="auto"):
def get_vae_from_settings(vae_file="auto"):
    # else, we load from settings, if not set to be default
    if vae_file == "auto" and shared.opts.sd_vae is not None:
        # if saved VAE settings isn't recognized, fallback to auto
        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
        # if VAE selected but not found, fallback to auto
        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
            vae_file = "auto"
            print("Selected VAE doesn't exist")
    return vae_file


def resolve_vae(checkpoint_file=None, vae_file="auto"):
    global first_load, vae_dict, vae_list

    # if vae_file argument is provided, it takes priority, but not saved
@@ -102,14 +114,9 @@ def resolve_vae(checkpoint_file, vae_file="auto"):
            shared.opts.data['sd_vae'] = get_filename(vae_file)
        else:
            print("VAE provided as command line argument doesn't exist")
    # else, we load from settings
    if vae_file == "auto" and shared.opts.sd_vae is not None:
        # if saved VAE settings isn't recognized, fallback to auto
        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
        # if VAE selected but not found, fallback to auto
        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
            vae_file = "auto"
            print("Selected VAE doesn't exist")
    # fallback to selector in settings, if vae selector not set to act as default fallback
    if not shared.opts.sd_vae_as_default:
        vae_file = get_vae_from_settings(vae_file)
    # vae-path cmd arg takes priority for auto
    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
Loading