Commit 97431f29 authored by aria1th's avatar aria1th
Browse files

fix double gc and decoding with unet context

parent ffd0f8dd
Loading
Loading
Loading
Loading
+2 −3
Original line number Original line Diff line number Diff line
@@ -874,7 +874,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            else:
            else:
                if opts.sd_vae_decode_method != 'Full':
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                with hypertile_context_unet(p.sd_model.model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
                with hypertile_context_vae(p.sd_model.first_stage_model, aspect_ratio=p.width / p.height, tile_size=largest_tile_size_available(p.width, p.height), opts=shared.opts):
                    x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)
                    x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)


            x_samples_ddim = torch.stack(x_samples_ddim).float()
            x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -1146,11 +1146,11 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):
        tile_size = largest_tile_size_available(self.width, self.height)
        tile_size = largest_tile_size_available(self.width, self.height)
        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
            with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
            with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
                devices.torch_gc()
                samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
                samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x
        del x
        if not self.enable_hr:
        if not self.enable_hr:
            return samples
            return samples
        devices.torch_gc()


        if self.latent_scale_mode is None:
        if self.latent_scale_mode is None:
            with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
            with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
@@ -1536,7 +1536,6 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        tile_size = largest_tile_size_available(self.width, self.height)
        tile_size = largest_tile_size_available(self.width, self.height)
        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
        with hypertile_context_vae(self.sd_model.first_stage_model, aspect_ratio=aspect_ratio, tile_size=tile_size, opts=shared.opts):
            with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
            with hypertile_context_unet(self.sd_model.model, aspect_ratio=aspect_ratio, tile_size=tile_size, is_sdxl=shared.sd_model.is_sdxl, opts=shared.opts):
                devices.torch_gc()
                samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)
                samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)


        if self.mask is not None:
        if self.mask is not None: