Commit 98ca437e authored by brkirch's avatar brkirch
Browse files

Refactor and instead check if mps is being used, not availability

parent 0b5dcb3d
Loading
Loading
Loading
Loading
+1 −5
Original line number Diff line number Diff line
@@ -182,11 +182,7 @@ def register_buffer(self, name, attr):

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:

            if devices.has_mps():
                attr = attr.to(device="mps", dtype=torch.float32)
            else:
                attr = attr.to(devices.device)
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)