Commit df3b31eb authored by brkirch's avatar brkirch
Browse files

In-place operations can break gradient calculation

parent 15123339
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -247,9 +247,9 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
        # restoring original mean is likely not correct, but it seems to work well to prevent artifacts that happen otherwise
        batch_multipliers = torch.asarray(batch_multipliers).to(devices.device)
        original_mean = z.mean()
        z *= batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        z = z * batch_multipliers.reshape(batch_multipliers.shape + (1,)).expand(z.shape)
        new_mean = z.mean()
        z *= original_mean / new_mean
        z = z * (original_mean / new_mean)

        return z