Commit 5b2c3168 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

eliminate duplicated code from #5095

parent 997ac570
Loading
Loading
Loading
Loading
+11 −19
Original line number Diff line number Diff line
@@ -24,17 +24,18 @@ def extract_device_id(args, name):
    return None


def get_optimal_device():
    if torch.cuda.is_available():
def get_cuda_device_string():
    from modules import shared

        device_id = shared.cmd_opts.device_id
    if shared.cmd_opts.device_id is not None:
        return f"cuda:{shared.cmd_opts.device_id}"

    return "cuda"

        if device_id is not None:
            cuda_device = f"cuda:{device_id}"
            return torch.device(cuda_device)
        else:
            return torch.device("cuda")

def get_optimal_device():
    if torch.cuda.is_available():
        return torch.device(get_cuda_device_string())

    if has_mps():
        return torch.device("mps")
@@ -44,16 +45,7 @@ def get_optimal_device():

def torch_gc():
    if torch.cuda.is_available():
        from modules import shared

        device_id = shared.cmd_opts.device_id
        
        if device_id is not None:
            cuda_device = f"cuda:{device_id}"
        else:
            cuda_device = "cuda"
        
        with torch.cuda.device(cuda_device):
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()