Unverified commit 799760ab, authored by AUTOMATIC1111 and committed by GitHub

Merge pull request #11722 from akx/mps-gc-fix

Fix MPS cache cleanup
Parents: 7b833291 b85fc718
modules/devices.py +3 −2
@@ -54,8 +54,9 @@ def torch_gc():
         with torch.cuda.device(get_cuda_device_string()):
             torch.cuda.empty_cache()
             torch.cuda.ipc_collect()
-    elif has_mps() and hasattr(torch.mps, 'empty_cache'):
-        torch.mps.empty_cache()
+
+    if has_mps():
+        mac_specific.torch_mps_gc()


 def enable_tf32():
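The devices.py hunk makes two changes: the MPS branch becomes an independent `if` rather than an `elif` chained off the CUDA check, and the cleanup itself moves behind `mac_specific.torch_mps_gc()`. A minimal standalone sketch of the resulting control flow, with hypothetical stand-ins for the repository helpers (`cuda_device` replaces `get_cuda_device_string()` and the `mps_available` flag replaces the `has_mps()` wrapper):

import logging

import torch

log = logging.getLogger()


def torch_mps_gc_sketch() -> None:
    # Mirrors the new mac_specific.torch_mps_gc(): import lazily and treat
    # any failure as non-fatal (see the second file in this diff).
    try:
        from torch.mps import empty_cache
        empty_cache()
    except Exception:
        log.warning("MPS garbage collection failed", exc_info=True)


def torch_gc_sketch(mps_available: bool, cuda_device: str = "cuda:0") -> None:
    if torch.cuda.is_available():
        with torch.cuda.device(cuda_device):
            torch.cuda.empty_cache()   # release cached allocator blocks
            torch.cuda.ipc_collect()   # reclaim CUDA IPC shared-memory handles

    # Independent `if`, not the old `elif`: the MPS cleanup no longer
    # depends on the CUDA check failing.
    if mps_available:
        torch_mps_gc_sketch()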
modules/mac_specific.py +14 −0
+import logging
+
 import torch
 import platform
 from modules.sd_hijack_utils import CondFunc
 from packaging import version

+log = logging.getLogger()
+

 # before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
 # use check `getattr` and try it for compatibility.
@@ -19,9 +23,19 @@ def check_for_mps() -> bool:
             return False
     else:
         return torch.backends.mps.is_available() and torch.backends.mps.is_built()

 has_mps = check_for_mps()


+def torch_mps_gc() -> None:
+    try:
+        from torch.mps import empty_cache
+        empty_cache()
+    except Exception:
+        log.warning("MPS garbage collection failed", exc_info=True)
+
+
 # MPS workaround for https://github.com/pytorch/pytorch/issues/89784
 def cumsum_fix(input, cumsum_func, *args, **kwargs):
     if input.device.type == 'mps':
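The new helper replaces the old `hasattr(torch.mps, 'empty_cache')` guard with a lazy import inside a try/except, so a missing or broken `torch.mps.empty_cache` degrades to a logged warning instead of an exception escaping `torch_gc()`. A hedged standalone sketch of the same pattern; the `best_effort_mps_cleanup` name and the boolean return value are illustrative additions, not part of the upstream helper:

import logging

logging.basicConfig(level=logging.WARNING)
log = logging.getLogger(__name__)


def best_effort_mps_cleanup() -> bool:
    # Attempt the backend-specific cleanup; an ImportError (pre-MPS or
    # CPU-only torch builds) or a runtime failure is logged, not raised.
    try:
        from torch.mps import empty_cache
        empty_cache()
        return True
    except Exception:
        log.warning("MPS garbage collection failed", exc_info=True)
        return False


if __name__ == "__main__":
    print("MPS cache emptied:", best_effort_mps_cleanup())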