Commit 7aab389d authored by brkirch's avatar brkirch
Browse files

Fix for Unet NaNs

parent 5ab7f213
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -256,6 +256,9 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
    k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
    v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)

    if q.device.type == 'mps':
        q, k, v = q.contiguous(), k.contiguous(), v.contiguous()

    dtype = q.dtype
    if shared.opts.upcast_attn:
        q, k = q.float(), k.float()