Commit a00cd8b9 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

attempt to fix memory monitor with multiple CUDA devices

parent 6033de18
Loading
Loading
Loading
Loading
+8 −4
Original line number Diff line number Diff line
@@ -23,12 +23,16 @@ class MemUsageMonitor(threading.Thread):
        self.data = defaultdict(int)

        try:
            torch.cuda.mem_get_info()
            self.cuda_mem_get_info()
            torch.cuda.memory_stats(self.device)
        except Exception as e:  # AMD or whatever
            print(f"Warning: caught exception '{e}', memory monitor disabled")
            self.disabled = True

    def cuda_mem_get_info(self):
        """Return (free, total) CUDA memory in bytes for the monitored device.

        Resolves the device index explicitly because ``torch.cuda.mem_get_info``
        can misbehave when handed a bare ``torch.device`` on multi-GPU setups;
        falls back to the current CUDA device when ``self.device`` has no index.
        """
        if self.device.index is None:
            device_index = torch.cuda.current_device()
        else:
            device_index = self.device.index
        return torch.cuda.mem_get_info(device_index)

    def run(self):
        if self.disabled:
            return
@@ -43,10 +47,10 @@ class MemUsageMonitor(threading.Thread):
                self.run_flag.clear()
                continue

            self.data["min_free"] = torch.cuda.mem_get_info()[0]
            self.data["min_free"] = self.cuda_mem_get_info()[0]

            while self.run_flag.is_set():
                free, total = torch.cuda.mem_get_info()  # calling with self.device errors, torch bug?
                free, total = self.cuda_mem_get_info()
                self.data["min_free"] = min(self.data["min_free"], free)

                time.sleep(1 / self.opts.memmon_poll_rate)
@@ -70,7 +74,7 @@ class MemUsageMonitor(threading.Thread):

    def read(self):
        if not self.disabled:
            free, total = torch.cuda.mem_get_info()
            free, total = self.cuda_mem_get_info()
            self.data["free"] = free
            self.data["total"] = total