Commit b85fc718 authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Fix MPS cache cleanup

Importing torch does not import torch.mps so the call failed.
parent 7b833291
Loading
Loading
Loading
Loading
+3 −2
Original line number Diff line number Diff line
@@ -54,8 +54,9 @@ def torch_gc():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif has_mps() and hasattr(torch.mps, 'empty_cache'):
        torch.mps.empty_cache()

    if has_mps():
        mac_specific.torch_mps_gc()


def enable_tf32():
+14 −0
Original line number Diff line number Diff line
import logging

import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version

log = logging.getLogger()


# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
            return False
    else:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()


has_mps = check_for_mps()


def torch_mps_gc() -> None:
    try:
        from torch.mps import empty_cache
        empty_cache()
    except Exception:
        log.warning("MPS garbage collection failed", exc_info=True)


# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
    if input.device.type == 'mps':