Commit c67c40f9 authored by Matthew McGoogan's avatar Matthew McGoogan
Browse files

torch.cuda.empty_cache() defaults to cuda:0 device unless explicitly set...

torch.cuda.empty_cache() defaults to cuda:0 device unless explicitly set otherwise first. Updating torch_gc() to use the device set by --device-id if specified to avoid OOM edge cases on multi-GPU systems.
parent b5050ad2
Loading
Loading
Loading
Loading
+12 −2
Original line number Diff line number Diff line
@@ -44,6 +44,16 @@ def get_optimal_device():

def torch_gc():
    if torch.cuda.is_available():
        from modules import shared

        device_id = shared.cmd_opts.device_id
        
        if device_id is not None:
            cuda_device = f"cuda:{device_id}"
        else:
            cuda_device = "cuda"
        
        with torch.cuda.device(cuda_device):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()