Commit c62d17ae authored by AUTOMATIC's avatar AUTOMATIC
Browse files

use the new devices.has_mps() function in register_buffer for DDIM/PLMS fix for OSX

parent 526f0aa5
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
@@ -418,8 +418,7 @@ def register_buffer(self, name, attr):
    if type(attr) == torch.Tensor:
        if attr.device != devices.device:

            # would this not break cuda when torch adds has_mps() to main version?
            if getattr(torch, 'has_mps', False):
            if devices.has_mps():
                attr = attr.to(device="mps", dtype=torch.float32)
            else:
                attr = attr.to(devices.device)