Commit 82380d9a authored by Jairo Correa's avatar Jairo Correa
Browse files

Removing parts no longer needed to fix vram

parent 1f50971f
Loading
Loading
Loading
Loading
+1 −2
Original line number Diff line number Diff line
import contextlib

import torch
import gc

from modules import errors

@@ -20,8 +19,8 @@ def get_optimal_device():

    return cpu


def torch_gc():
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
        torch.cuda.ipc_collect()
+8 −13
Original line number Diff line number Diff line
@@ -346,7 +346,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
            state.job_count = p.n_iter

        for n in range(p.n_iter):
        with torch.no_grad(), precision_scope("cuda"), ema_scope():
            if state.interrupted:
                break

@@ -396,21 +395,18 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)

            for i, x_sample in enumerate(x_samples_ddim):
            with torch.no_grad(), precision_scope("cuda"), ema_scope():
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
                x_sample = x_sample.astype(np.uint8)

                if p.restore_faces:
                with torch.no_grad(), precision_scope("cuda"), ema_scope():
                    if opts.save and not p.do_not_save_samples and opts.save_images_before_face_restoration:
                        images.save_image(Image.fromarray(x_sample), p.outpath_samples, "", seeds[i], prompts[i], opts.samples_format, info=infotext(n, i), p=p, suffix="-before-face-restoration")

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

                    x_sample = modules.face_restoration.restore_faces(x_sample)
                    devices.torch_gc()

            with torch.no_grad(), precision_scope("cuda"), ema_scope():
                image = Image.fromarray(x_sample)

                if p.color_corrections is not None and i < len(p.color_corrections):
@@ -444,7 +440,6 @@ def process_images(p: StableDiffusionProcessing) -> Processed:

            state.nextjob()

    with torch.no_grad(), precision_scope("cuda"), ema_scope():
        p.color_corrections = None

        index_of_first_image = 0