Commit 15e89ef0 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

fix for unet hijack breaking the train tab

parent 789d47f8
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -36,8 +36,11 @@ th = TorchHijackForUnet()

# Below are monkey patches to enable upcasting a float16 UNet for float32 sampling
def apply_model(orig_func, self, x_noisy, t, cond, **kwargs):

    if isinstance(cond, dict):
        for y in cond.keys():
            cond[y] = [x.to(devices.dtype_unet) if isinstance(x, torch.Tensor) else x for x in cond[y]]

    with devices.autocast():
        return orig_func(self, x_noisy.to(devices.dtype_unet), t.to(devices.dtype_unet), cond, **kwargs).float()