Commit 180fdf78 authored by Alex "mcmonkey" Goodwin's avatar Alex "mcmonkey" Goodwin
Browse files

apply to DPM2 (non-ancestral) as well

parent 8b0703b8
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -494,7 +494,7 @@ class KDiffusionSampler:

        x = x * sigmas[0]

        if self.funcname == "sample_dpm_2_ancestral": # workaround dpm2 a issue
        if self.funcname in ['sample_dpm_2_ancestral', 'sample_dpm_2']:
            sigmas = torch.cat([sigmas[:-2], sigmas[-1:]])

        extra_params_kwargs = self.initialize(p)