Unverified Commit a407c9f0 authored by DaniAndTheWeb's avatar DaniAndTheWeb Committed by GitHub
Browse files

Automatic torch install for amd on linux

This commit allows the launch script to automatically download rocm's torch version for AMD GPUs using an external GPU detection script. It also prints the operative system and GPU in use.
parent eaebcf63
Loading
Loading
Loading
Loading
+14 −1
Original line number Diff line number Diff line
@@ -7,6 +7,7 @@ import shlex
import platform
import argparse
import json
import detection

dir_repos = "repositories"
dir_extensions = "extensions"
@@ -15,6 +16,12 @@ git = os.environ.get('GIT', "git")
index_url = os.environ.get('INDEX_URL', "")
stored_commit_hash = None

# Get the GPU vendor and the operating system
gpu = detection.check_gpu()
if os.name == "posix":
    os_name = platform.uname().system
else:
    os_name = os.name

def commit_hash():
    global stored_commit_hash
@@ -173,7 +180,11 @@ def run_extensions_installers(settings_file):


def prepare_environment():
    if gpu == "AMD" and os_name !="nt":
        torch_command = os.environ.get('TORCH_COMMAND', "pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/rocm5.2")
    else:
        torch_command = os.environ.get('TORCH_COMMAND', "pip install torch==1.12.1+cu113 torchvision==0.13.1+cu113 --extra-index-url https://download.pytorch.org/whl/cu113")
    
    requirements_file = os.environ.get('REQS_FILE', "requirements_versions.txt")
    commandline_args = os.environ.get('COMMANDLINE_ARGS', "")

@@ -295,6 +306,8 @@ def tests(test_dir):


def start():
    print(f"Operating System: {os_name}")
    print(f"GPU: {gpu}")
    print(f"Launching {'API server' if '--nowebui' in sys.argv else 'Web UI'} with arguments: {' '.join(sys.argv[1:])}")
    import webui
    if '--nowebui' in sys.argv: