Commit ae17e978 authored by Sakura-Luna's avatar Sakura-Luna
Browse files

UniPC progress bar adjustment

parent 22bcc7be
Loading
Loading
Loading
Loading
+37 −33
Original line number Diff line number Diff line
import torch
import torch.nn.functional as F
import math
from tqdm.auto import trange
import tqdm


class NoiseScheduleVP:
@@ -757,6 +757,7 @@ class UniPC:
                vec_t = timesteps[0].expand((x.shape[0]))
                model_prev_list = [self.model_fn(x, vec_t)]
                t_prev_list = [vec_t]
                with tqdm.tqdm(total=steps) as pbar:
                    # Init the first `order` values by lower order multistep DPM-Solver.
                    for init_order in range(1, order):
                        vec_t = timesteps[init_order].expand(x.shape[0])
@@ -767,7 +768,9 @@ class UniPC:
                            self.after_update(x, model_x)
                        model_prev_list.append(model_x)
                        t_prev_list.append(vec_t)
                for step in trange(order, steps + 1):
                        pbar.update()

                    for step in range(order, steps + 1):
                        vec_t = timesteps[step].expand(x.shape[0])
                        if lower_order_final:
                            step_order = min(order, steps + 1 - step)
@@ -791,6 +794,7 @@ class UniPC:
                            if model_x is None:
                                model_x = self.model_fn(x, vec_t)
                            model_prev_list[-1] = model_x
                        pbar.update()
        else:
            raise NotImplementedError()
        if denoise_to_zero: