Commit a4cb96d4 authored by brkirch's avatar brkirch
Browse files

Remove test, use bool tensor fix by default

The test isn't working correctly on macOS 13.3 and the bool tensor fix for cumsum is currently always needed anyway, so enable the fix by default.
parent 27e319dc
Loading
Loading
Loading
Loading
+1 −2
Original line number Original line Diff line number Diff line
@@ -23,7 +23,7 @@ def cumsum_fix(input, cumsum_func, *args, **kwargs):
        output_dtype = kwargs.get('dtype', input.dtype)
        output_dtype = kwargs.get('dtype', input.dtype)
        if output_dtype == torch.int64:
        if output_dtype == torch.int64:
            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
            return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
        elif cumsum_needs_bool_fix and output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
        elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
            return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
    return cumsum_func(input, *args, **kwargs)
    return cumsum_func(input, *args, **kwargs)


@@ -45,7 +45,6 @@ if has_mps:
        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
        CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
    elif version.parse(torch.__version__) > version.parse("1.13.1"):
    elif version.parse(torch.__version__) > version.parse("1.13.1"):
        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
        cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
        cumsum_needs_bool_fix = not torch.BoolTensor([True,True]).to(device=torch.device("mps"), dtype=torch.int64).equal(torch.BoolTensor([True,False]).to(torch.device("mps")).cumsum(0))
        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
        cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
        CondFunc('torch.cumsum', cumsum_fix_func, None)
        CondFunc('torch.cumsum', cumsum_fix_func, None)
        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
        CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)