Commit 37e048a7 authored by lambertae's avatar lambertae
Browse files

fix floating error

parent 15a94d6c
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
@@ -89,6 +89,7 @@ def restart_sampler(model, x, sigmas, extra_args=None, callback=None, disable=No
            restart_steps, restart_times, restart_max = restart_list[i + 1]
            min_idx = i + 1
            max_idx = int(torch.argmin(abs(sigmas - restart_max), dim=0))
            if max_idx < min_idx:
                sigma_restart = get_sigmas_karras(restart_steps, sigmas[min_idx], sigmas[max_idx], device=sigmas.device)[:-1] # remove the zero at the end
                for times in range(restart_times):
                    x = x + torch.randn_like(x) * s_noise * (sigmas[max_idx] ** 2 - sigmas[min_idx] ** 2) ** 0.5