Unverified Commit 67efee33 authored by klimaleksus's avatar klimaleksus Committed by GitHub
Browse files

Make VAE step sequential to prevent VRAM spikes

parent 0b5dcb3d
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -530,8 +530,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            with devices.autocast():
                samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)

            samples_ddim = samples_ddim.to(devices.dtype_vae)
            x_samples_ddim = decode_first_stage(p.sd_model, samples_ddim)
            x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim