Unverified Commit ea05ddfe authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #10201 from brkirch/mps-nan-fixes

Fix MPS on PyTorch 2.0.1, Intel Macs
parents 2b96a7b6 de401d8f
Loading
Loading
Loading
Loading
+7 −2
Original line number Diff line number Diff line
@@ -54,6 +54,11 @@ if has_mps:
        CondFunc('torch.cumsum', cumsum_fix_func, None)
        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
        CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)
    if version.parse(torch.__version__) == version.parse("2.0"):

        # MPS workaround for https://github.com/pytorch/pytorch/issues/96113
        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda *args, **kwargs: len(args) == 6)
        CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')

        # MPS workaround for https://github.com/pytorch/pytorch/issues/92311
        if platform.processor() == 'i386':
            for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
                CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
 No newline at end of file
+3 −0
Original line number Diff line number Diff line
@@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()