Commit e247b740 authored by brkirch's avatar brkirch
Browse files

Add fixes for PyTorch 1.12.1

Fix typo "MasOS" -> "macOS"

If MPS is available and PyTorch is an earlier version than 1.13:
* Monkey patch torch.Tensor.to to ensure all tensors sent to MPS are contiguous
* Monkey patch torch.nn.functional.layer_norm to ensure input tensor is contiguous (required for this program to work with MPS on unmodified PyTorch 1.12.1)
parent a5106a7c
Loading
Loading
Loading
Loading
+27 −1
Original line number Diff line number Diff line
@@ -2,9 +2,10 @@ import sys, os, shlex
import contextlib
import torch
from modules import errors
from packaging import version


# has_mps is only available in nightly pytorch (for now) and MasOS 12.3+.
# has_mps is only available in nightly pytorch (for now) and macOS 12.3+.
# check `getattr` and try it for compatibility
def has_mps() -> bool:
    if not getattr(torch, 'has_mps', False):
@@ -94,3 +95,28 @@ def autocast(disable=False):
        return contextlib.nullcontext()

    return torch.autocast("cuda")


# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
orig_tensor_to = torch.Tensor.to
def tensor_to_fix(self, *args, **kwargs):
    if self.device.type != 'mps' and \
       ((len(args) > 0 and isinstance(args[0], torch.device) and args[0].type == 'mps') or \
       (isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')):
        self = self.contiguous()
    return orig_tensor_to(self, *args, **kwargs)


# MPS workaround for https://github.com/pytorch/pytorch/issues/80800 
orig_layer_norm = torch.nn.functional.layer_norm
def layer_norm_fix(*args, **kwargs):
    if len(args) > 0 and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps':
        args = list(args)
        args[0] = args[0].contiguous()
    return orig_layer_norm(*args, **kwargs)


# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working
if has_mps() and version.parse(torch.__version__) < version.parse("1.13"):
    torch.Tensor.to = tensor_to_fix
    torch.nn.functional.layer_norm = layer_norm_fix