Commit 48f4abd2 authored by EllangoK's avatar EllangoK
Browse files

fix dims typo in unipc

parent 27e319dc
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -719,7 +719,7 @@ class UniPC:
                x_t = x_t_ - expand_dims(alpha_t * B_h, dims) * (corr_res + rhos_c[-1] * D1_t)
        else:
            x_t_ = (
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dimss) * x
                expand_dims(torch.exp(log_alpha_t - log_alpha_prev_0), dims) * x
                - expand_dims(sigma_t * h_phi_1, dims) * model_prev_0
            )
            if x_t is None: