Unverified Commit 2cfcb23c authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #1283 from jn-jairo/fix-vram

Fix memory leak and reduce memory usage
parents 82eb8ea4 b66aa334
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -100,6 +100,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v

        outputs.append(image)

    devices.torch_gc()

    return outputs, plaintext_to_html(info), ''


+15 −1
Original line number Diff line number Diff line
@@ -11,7 +11,7 @@ import cv2
from skimage import exposure

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -382,6 +382,13 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
            x_samples_ddim = p.sd_model.decode_first_stage(samples_ddim)
            x_samples_ddim = torch.clamp((x_samples_ddim + 1.0) / 2.0, min=0.0, max=1.0)

            del samples_ddim

            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
                lowvram.send_everything_to_cpu()

            devices.torch_gc()

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
@@ -426,6 +433,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
                infotexts.append(infotext(n, i))
                output_images.append(image)

            del x_samples_ddim 

            devices.torch_gc()

            state.nextjob()

        p.color_corrections = None
@@ -663,4 +674,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        if self.mask is not None:
            samples = samples * self.nmask + self.init_latent * self.mask

        del x
        devices.torch_gc()

        return samples
+3 −1
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@ import traceback
import torch
import numpy as np
from torch import einsum
from torch.nn.functional import silu

import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
@@ -19,11 +20,12 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.At


def apply_optimizations():
    ldm.modules.diffusionmodules.model.nonlinearity = silu

    if cmd_opts.opt_split_attention_v1:
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.nonlinearity = sd_hijack_optimizations.nonlinearity_hijack
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward


+0 −8
Original line number Diff line number Diff line
@@ -92,14 +92,6 @@ def split_cross_attention_forward(self, x, context=None, mask=None):

    return self.to_out(r2)

def nonlinearity_hijack(x):
    # swish
    t = torch.sigmoid(x)
    x *= t
    del t

    return x

def cross_attention_attnblock_forward(self, x):
        h_ = x
        h_ = self.norm(h_)