Commit 21880eb9 authored by space-nuko's avatar space-nuko
Browse files

Fix logspam and live previews

parent 12531998
Loading
Loading
Loading
Loading
+15 −5
Original line number Diff line number Diff line
@@ -19,9 +19,10 @@ class UniPCSampler(object):
                attr = attr.to(torch.device("cuda"))
        setattr(self, name, attr)

    def set_hooks(self, before, after):
        self.before_sample = before
        self.after_sample = after
    def set_hooks(self, before_sample, after_sample, after_update):
        self.before_sample = before_sample
        self.after_sample = after_sample
        self.after_update = after_update

    @torch.no_grad()
    def sample(self,
@@ -50,9 +51,17 @@ class UniPCSampler(object):
               ):
        if conditioning is not None:
            if isinstance(conditioning, dict):
                cbs = conditioning[list(conditioning.keys())[0]].shape[0]
                ctmp = conditioning[list(conditioning.keys())[0]]
                while isinstance(ctmp, list): ctmp = ctmp[0]
                cbs = ctmp.shape[0]
                if cbs != batch_size:
                    print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            elif isinstance(conditioning, list):
                for ctmp in conditioning:
                    if ctmp.shape[0] != batch_size:
                        print(f"Warning: Got {cbs} conditionings but batch-size is {batch_size}")

            else:
                if conditioning.shape[0] != batch_size:
                    print(f"Warning: Got {conditioning.shape[0]} conditionings but batch-size is {batch_size}")
@@ -60,6 +69,7 @@ class UniPCSampler(object):
        # sampling
        C, H, W = shape
        size = (batch_size, C, H, W)
        print(f'Data shape for UniPC sampling is {size}, eta {eta}')

        device = self.model.betas.device
        if x_T is None:
@@ -79,7 +89,7 @@ class UniPCSampler(object):
            guidance_scale=unconditional_guidance_scale,
        )

        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample)
        uni_pc = UniPC(model_fn, ns, predict_x0=True, thresholding=False, condition=conditioning, unconditional_condition=unconditional_conditioning, before_sample=self.before_sample, after_sample=self.after_sample, after_update=self.after_update)
        x = uni_pc.sample(img, steps=S, skip_type="time_uniform", method="multistep", order=3, lower_order_final=True)

        return x.to(device), None
+15 −17
Original line number Diff line number Diff line
@@ -378,7 +378,8 @@ class UniPC:
        condition=None,
        unconditional_condition=None,
        before_sample=None,
        after_sample=None
        after_sample=None,
        after_update=None
    ):
        """Construct a UniPC.

@@ -394,6 +395,7 @@ class UniPC:
        self.unconditional_condition = unconditional_condition
        self.before_sample = before_sample
        self.after_sample = after_sample
        self.after_update = after_update

    def dynamic_thresholding_fn(self, x0, t=None):
        """
@@ -434,15 +436,6 @@ class UniPC:
        noise = self.noise_prediction_fn(x, t)
        dims = x.dim()
        alpha_t, sigma_t = self.noise_schedule.marginal_alpha(t), self.noise_schedule.marginal_std(t)
        from pprint import pp
        print("X:")
        pp(x)
        print("sigma_t:")
        pp(sigma_t)
        print("noise:")
        pp(noise)
        print("alpha_t:")
        pp(alpha_t)
        x0 = (x - expand_dims(sigma_t, dims) * noise) / expand_dims(alpha_t, dims)
        if self.thresholding:
            p = 0.995   # A hyperparameter in the paper of "Imagen" [1].
@@ -524,7 +517,7 @@ class UniPC:
            return self.multistep_uni_pc_vary_update(x, model_prev_list, t_prev_list, t, order, **kwargs)

    def multistep_uni_pc_vary_update(self, x, model_prev_list, t_prev_list, t, order, use_corrector=True):
        print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        #print(f'using unified predictor-corrector with order {order} (solver type: vary coeff)')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)

@@ -568,7 +561,7 @@ class UniPC:
            A_p = C_inv_p

        if use_corrector:
            print('using corrector')
            #print('using corrector')
            C_inv = torch.linalg.inv(C)
            A_c = C_inv

@@ -627,7 +620,7 @@ class UniPC:
        return x_t, model_t

    def multistep_uni_pc_bh_update(self, x, model_prev_list, t_prev_list, t, order, x_t=None, use_corrector=True):
        print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        #print(f'using unified predictor-corrector with order {order} (solver type: B(h))')
        ns = self.noise_schedule
        assert order <= len(model_prev_list)
        dims = x.dim()
@@ -695,7 +688,7 @@ class UniPC:
            D1s = None

        if use_corrector:
            print('using corrector')
            #print('using corrector')
            # for order 1, we use a simplified version
            if order == 1:
                rhos_c = torch.tensor([0.5], device=b.device)
@@ -755,8 +748,9 @@ class UniPC:
        t_T = self.noise_schedule.T if t_start is None else t_start
        device = x.device
        if method == 'multistep':
            assert steps >= order
            assert steps >= order, "UniPC order must be < sampling steps"
            timesteps = self.get_time_steps(skip_type=skip_type, t_T=t_T, t_0=t_0, N=steps, device=device)
            print(f"Running UniPC Sampling with {timesteps.shape[0]} timesteps")
            assert timesteps.shape[0] - 1 == steps
            with torch.no_grad():
                vec_t = timesteps[0].expand((x.shape[0]))
@@ -768,6 +762,8 @@ class UniPC:
                    x, model_x = self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, init_order, use_corrector=True)
                    if model_x is None:
                        model_x = self.model_fn(x, vec_t)
                    if self.after_update is not None:
                        self.after_update(x, model_x)
                    model_prev_list.append(model_x)
                    t_prev_list.append(vec_t)
                for step in range(order, steps + 1):
@@ -776,13 +772,15 @@ class UniPC:
                        step_order = min(order, steps + 1 - step)
                    else:
                        step_order = order
                    print('this step order:', step_order)
                    #print('this step order:', step_order)
                    if step == steps:
                        print('do not run corrector at the last step')
                        #print('do not run corrector at the last step')
                        use_corrector = False
                    else:
                        use_corrector = True
                    x, model_x =  self.multistep_uni_pc_update(x, model_prev_list, t_prev_list, vec_t, step_order, use_corrector=use_corrector)
                    if self.after_update is not None:
                        self.after_update(x, model_x)
                    for i in range(order - 1):
                        t_prev_list[i] = t_prev_list[i + 1]
                        model_prev_list[i] = model_prev_list[i + 1]
+11 −9
Original line number Diff line number Diff line
@@ -103,16 +103,11 @@ class VanillaStableDiffusionSampler:

        return x, ts, cond, unconditional_conditioning

    def after_sample(self, x, ts, cond, uncond, res):
        if self.is_unipc:
            # unipc model_fn returns (pred_x0)
            # p_sample_ddim returns (x_prev, pred_x0)
            res = (None, res[0])

    def update_step(self, last_latent):
        if self.mask is not None:
            self.last_latent = self.init_latent * self.mask + self.nmask * res[1]
            self.last_latent = self.init_latent * self.mask + self.nmask * last_latent
        else:
            self.last_latent = res[1]
            self.last_latent = last_latent

        sd_samplers_common.store_latent(self.last_latent)

@@ -120,8 +115,15 @@ class VanillaStableDiffusionSampler:
        state.sampling_step = self.step
        shared.total_tqdm.update()

    def after_sample(self, x, ts, cond, uncond, res):
        if not self.is_unipc:
            self.update_step(res[1])

        return x, ts, cond, uncond, res

    def unipc_after_update(self, x, model_x):
        self.update_step(x)

    def initialize(self, p):
        self.eta = p.eta if p.eta is not None else shared.opts.eta_ddim
        if self.eta != 0.0:
@@ -131,7 +133,7 @@ class VanillaStableDiffusionSampler:
            if hasattr(self.sampler, fieldname):
                setattr(self.sampler, fieldname, self.p_sample_ddim_hook)
        if self.is_unipc:
            self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r))
            self.sampler.set_hooks(lambda x, t, c, u: self.before_sample(x, t, c, u), lambda x, t, c, u, r: self.after_sample(x, t, c, u, r), lambda x, mx: self.unipc_after_update(x, mx))

        self.mask = p.mask if hasattr(p, 'mask') else None
        self.nmask = p.nmask if hasattr(p, 'nmask') else None