Commit d35bf649 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make launch.py run installers for extensions that have ones

add some more classes to safety module for an extension
parent f126986b
Loading
Loading
Loading
Loading
+20 −2
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import shlex
import platform

dir_repos = "repositories"
dir_extensions = "extensions"
python = sys.executable
git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
@@ -101,7 +102,22 @@ def version_check(commit):
        else:
            print("Not a git clone, can't perform version check.")
    except Exception as e:
        print("versipm check failed",e)
        print("version check failed", e)


def run_extensions_installers():
    if not os.path.isdir(dir_extensions):
        return

    for dirname_extension in os.listdir(dir_extensions):
        path_installer = os.path.join(dir_extensions, dirname_extension, "install.py")
        if not os.path.isfile(path_installer):
            continue

        try:
            print(run(f'"{python}" "{path_installer}"', errdesc=f"Error running install.py for extension {dirname_extension}"))
        except Exception as e:
            print(e, file=sys.stderr)


def prepare_enviroment():
@@ -189,6 +205,8 @@ def prepare_enviroment():

    run_pip(f"install -r {requirements_file}", "requirements for Web UI")

    run_extensions_installers()

    if update_check:
        version_check(commit)
    
+1 −1
Original line number Diff line number Diff line
@@ -32,7 +32,7 @@ class RestrictedUnpickler(pickle.Unpickler):
            return getattr(collections, name)
        if module == 'torch._utils' and name in ['_rebuild_tensor_v2', '_rebuild_parameter']:
            return getattr(torch._utils, name)
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage']:
        if module == 'torch' and name in ['FloatStorage', 'HalfStorage', 'IntStorage', 'LongStorage', 'DoubleStorage', 'ByteStorage']:
            return getattr(torch, name)
        if module == 'torch.nn.modules.container' and name in ['ParameterDict']:
            return getattr(torch.nn.modules.container, name)