Commit b85fc718 authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Fix MPS cache cleanup

Importing torch does not import torch.mps so the call failed.
parent 7b833291
Loading
Loading
Loading
Loading
+3 −2
Original line number Original line Diff line number Diff line
@@ -54,8 +54,9 @@ def torch_gc():
        with torch.cuda.device(get_cuda_device_string()):
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
            torch.cuda.ipc_collect()
    elif has_mps() and hasattr(torch.mps, 'empty_cache'):

        torch.mps.empty_cache()
    if has_mps():
        mac_specific.torch_mps_gc()




def enable_tf32():
def enable_tf32():
+14 −0
Original line number Original line Diff line number Diff line
import logging

import torch
import torch
import platform
import platform
from modules.sd_hijack_utils import CondFunc
from modules.sd_hijack_utils import CondFunc
from packaging import version
from packaging import version


log = logging.getLogger()



# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
# use check `getattr` and try it for compatibility.
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
            return False
            return False
    else:
    else:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()


has_mps = check_for_mps()
has_mps = check_for_mps()




def torch_mps_gc() -> None:
    try:
        from torch.mps import empty_cache
        empty_cache()
    except Exception:
        log.warning("MPS garbage collection failed", exc_info=True)


# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
def cumsum_fix(input, cumsum_func, *args, **kwargs):
    if input.device.type == 'mps':
    if input.device.type == 'mps':