Commit 8c801362 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

split commandline args into its own file

make launch.py use the same command line argument parser as the main program
parent 3ec7e19f
Loading
Loading
Loading
Loading
+23 −54
Original line number Diff line number Diff line
@@ -5,28 +5,27 @@ import sys
import importlib.util
import shlex
import platform
import argparse
import json

parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--ui-settings-file", type=str, default='config.json')
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.realpath(__file__)))
args, _ = parser.parse_known_args(sys.argv)
from modules import cmd_args
from modules.paths_internal import script_path, extensions_dir

script_path = os.path.dirname(__file__)
data_path = args.data_dir
commandline_args = os.environ.get('COMMANDLINE_ARGS', "")
sys.argv += shlex.split(commandline_args)

args, _ = cmd_args.parser.parse_known_args()

dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None
skip_install = False
dir_repos = "repositories"

if 'GRADIO_ANALYTICS_ENABLED' not in os.environ:
    os.environ['GRADIO_ANALYTICS_ENABLED'] = 'False'


def check_python_version():
    is_windows = platform.system() == "Windows"
    major = sys.version_info.major
@@ -72,23 +71,6 @@ def commit_hash():
    return stored_commit_hash


def extract_arg(args, name):
    return [x for x in args if x != name], name in args


def extract_opt(args, name):
    opt = None
    is_present = False
    if name in args:
        is_present = True
        idx = args.index(name)
        del args[idx]
        if idx < len(args) and args[idx][0] != "-":
            opt = args[idx]
            del args[idx]
    return args, is_present, opt


def run(command, desc=None, errdesc=None, custom_env=None, live=False):
    if desc is not None:
        print(desc)
@@ -225,15 +207,15 @@ def list_extensions(settings_file):

    disabled_extensions = set(settings.get('disabled_extensions', []))

    return [x for x in os.listdir(os.path.join(data_path, dir_extensions)) if x not in disabled_extensions]
    return [x for x in os.listdir(extensions_dir) if x not in disabled_extensions]


def run_extensions_installers(settings_file):
    if not os.path.isdir(dir_extensions):
    if not os.path.isdir(extensions_dir):
        return

    for dirname_extension in list_extensions(settings_file):
        run_extension_installer(os.path.join(data_path, dir_extensions, dirname_extension))
        run_extension_installer(os.path.join(extensions_dir, dirname_extension))


def prepare_environment():
@@ -241,7 +223,6 @@ def prepare_environment():

    torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.13.1+cu117 torchvision==0.14.1+cu117 --extra-index-url https://download.pytorch.org/whl/cu117")
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
    commandline_args = os.environ.get('COMMANDLINE_ARGS', "")

    xformers_package = os.environ.get('XFORMERS_PACKAGE', 'xformers==0.0.16rc425')
    gfpgan_package = os.environ.get('GFPGAN_PACKAGE', "git+https://github.com/TencentARC/GFPGAN.git@8d2447a2d918f8eba5a4a01463fd48e45126a379")
@@ -260,21 +241,7 @@ def prepare_environment():
    codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
    blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")

    sys.argv += shlex.split(commandline_args)

    sys.argv, _ = extract_arg(sys.argv, '-f')
    sys.argv, update_all_extensions = extract_arg(sys.argv, '--update-all-extensions')
    sys.argv, skip_torch_cuda_test = extract_arg(sys.argv, '--skip-torch-cuda-test')
    sys.argv, skip_python_version_check = extract_arg(sys.argv, '--skip-python-version-check')
    sys.argv, reinstall_xformers = extract_arg(sys.argv, '--reinstall-xformers')
    sys.argv, reinstall_torch = extract_arg(sys.argv, '--reinstall-torch')
    sys.argv, update_check = extract_arg(sys.argv, '--update-check')
    sys.argv, run_tests, test_dir = extract_opt(sys.argv, '--tests')
    sys.argv, skip_install = extract_arg(sys.argv, '--skip-install')
    xformers = '--xformers' in sys.argv
    ngrok = '--ngrok' in sys.argv

    if not skip_python_version_check:
    if not args.skip_python_version_check:
        check_python_version()

    commit = commit_hash()
@@ -282,10 +249,10 @@ def prepare_environment():
    print(f"Python {sys.version}")
    print(f"Commit hash: {commit}")

    if reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
    if args.reinstall_torch or not is_installed("torch") or not is_installed("torchvision"):
        run(f'"{python}" -m {torch_command}', "Installing torch and torchvision", "Couldn't install torch", live=True)

    if not skip_torch_cuda_test:
    if not args.skip_torch_cuda_test:
        run_python("import torch; assert torch.cuda.is_available(), 'Torch is not able to use GPU; add --skip-torch-cuda-test to COMMANDLINE_ARGS variable to disable this check'")

    if not is_installed("gfpgan"):
@@ -297,7 +264,7 @@ def prepare_environment():
    if not is_installed("open_clip"):
        run_pip(f"install {openclip_package}", "open_clip")

    if (not is_installed("xformers") or reinstall_xformers) and xformers:
    if (not is_installed("xformers") or args.reinstall_xformers) and args.xformers:
        if platform.system() == "Windows":
            if platform.python_version().startswith("3.10"):
                run_pip(f"install -U -I --no-deps {xformers_package}", "xformers")
@@ -309,7 +276,7 @@ def prepare_environment():
        elif platform.system() == "Linux":
            run_pip(f"install {xformers_package}", "xformers")

    if not is_installed("pyngrok") and ngrok:
    if not is_installed("pyngrok") and args.ngrok:
        run_pip("install pyngrok", "ngrok")

    os.makedirs(os.path.join(script_path, dir_repos), exist_ok=True)
@@ -329,18 +296,18 @@ def prepare_environment():

    run_extensions_installers(settings_file=args.ui_settings_file)

    if update_check:
    if args.update_check:
        version_check(commit)

    if update_all_extensions:
        git_pull_recursive(os.path.join(data_path, dir_extensions))
    if args.update_all_extensions:
        git_pull_recursive(extensions_dir)
    
    if "--exit" in sys.argv:
        print("Exiting because of --exit argument")
        exit(0)

    if run_tests:
        exitcode = tests(test_dir)
    if args.tests and not args.no_tests:
        exitcode = tests(args.tests)
        exit(exitcode)


@@ -354,6 +321,8 @@ def tests(test_dir):
        sys.argv.append("--skip-torch-cuda-test")
    if "--disable-nan-check" not in sys.argv:
        sys.argv.append("--disable-nan-check")
    if "--no-tests" not in sys.argv:
        sys.argv.append("--no-tests")

    print(f"Launching Web UI in another process for testing with arguments: {' '.join(sys.argv[1:])}")

+12 −661

File changed.

Preview size limit exceeded, changes collapsed.

+7 −9
Original line number Diff line number Diff line
@@ -8,11 +8,9 @@ import git
from modules import paths, shared

extensions = []
extensions_dir = os.path.join(paths.data_path, "extensions")
extensions_builtin_dir = os.path.join(paths.script_path, "extensions-builtin")

if not os.path.exists(extensions_dir):
    os.makedirs(extensions_dir)
if not os.path.exists(paths.extensions_dir):
    os.makedirs(paths.extensions_dir)

def active():
    return [x for x in extensions if x.enabled]
@@ -86,11 +84,11 @@ class Extension:
def list_extensions():
    extensions.clear()

    if not os.path.isdir(extensions_dir):
    if not os.path.isdir(paths.extensions_dir):
        return

    paths = []
    for dirname in [extensions_dir, extensions_builtin_dir]:
    extension_paths = []
    for dirname in [paths.extensions_dir, paths.extensions_builtin_dir]:
        if not os.path.isdir(dirname):
            return

@@ -99,9 +97,9 @@ def list_extensions():
            if not os.path.isdir(path):
                continue

            paths.append((extension_dirname, path, dirname == extensions_builtin_dir))
            extension_paths.append((extension_dirname, path, dirname == paths.extensions_builtin_dir))

    for dirname, path, is_builtin in paths:
    for dirname, path, is_builtin in extension_paths:
        extension = Extension(name=dirname, path=path, enabled=dirname not in shared.opts.disabled_extensions, is_builtin=is_builtin)
        extensions.append(extension)
+2 −9
Original line number Diff line number Diff line
import argparse
import os
import sys
import modules.safe
from modules.paths_internal import models_path, script_path, data_path, extensions_dir, extensions_builtin_dir

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))
import modules.safe

# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser = argparse.ArgumentParser(add_help=False)
parser.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser.parse_known_args()[0]
data_path = cmd_opts_pre.data_dir
models_path = os.path.join(data_path, "models")

# data_path = cmd_opts_pre.data
sys.path.insert(0, script_path)
+22 −0
Original line number Diff line number Diff line
"""this module defines internal paths used by program and is safe to import before dependencies are installed in launch.py"""

import argparse
import os

script_path = os.path.dirname(os.path.dirname(os.path.realpath(__file__)))

sd_configs_path = os.path.join(script_path, "configs")
sd_default_config = os.path.join(sd_configs_path, "v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file

# Parse the --data-dir flag first so we can use it as a base for our other argument default values
parser_pre = argparse.ArgumentParser(add_help=False)
parser_pre.add_argument("--data-dir", type=str, default=os.path.dirname(os.path.dirname(os.path.realpath(__file__))), help="base path where all user data is stored",)
cmd_opts_pre = parser_pre.parse_known_args()[0]

data_path = cmd_opts_pre.data_dir

models_path = os.path.join(data_path, "models")
extensions_dir = os.path.join(data_path, "extensions")
extensions_builtin_dir = os.path.join(script_path, "extensions-builtin")
Loading