Unverified Commit 01f2ed68 authored by AUTOMATIC1111, committed by GitHub

Merge pull request #5065 from JaySmithWpg/vram-leak

#3449 - VRAM leak when switching to/from inpainting checkpoint
parents 151e2cc6 c833d5bf
+15 −18
-from collections import namedtuple
+from collections import namedtuple, deque
 import numpy as np
 from math import floor
 import torch
@@ -344,18 +344,28 @@ class CFGDenoiser(torch.nn.Module):


 class TorchHijack:
-    def __init__(self, kdiff_sampler):
-        self.kdiff_sampler = kdiff_sampler
+    def __init__(self, sampler_noises):
+        # Using a deque to efficiently receive the sampler_noises in the same order as the previous index-based
+        # implementation.
+        self.sampler_noises = deque(sampler_noises)

     def __getattr__(self, item):
         if item == 'randn_like':
-            return self.kdiff_sampler.randn_like
+            return self.randn_like

         if hasattr(torch, item):
             return getattr(torch, item)

         raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))

+    def randn_like(self, x):
+        if self.sampler_noises:
+            noise = self.sampler_noises.popleft()
+            if noise.shape == x.shape:
+                return noise
+
+        return torch.randn_like(x)
+

 class KDiffusionSampler:
     def __init__(self, funcname, sd_model):
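For reference, here is a minimal, self-contained sketch of the behaviour the reworked hijack introduces. The NoiseQueue name is hypothetical and only stands in for TorchHijack so the snippet runs outside the webui: queued tensors are handed out FIFO via popleft(), and a shape mismatch or an exhausted queue falls back to fresh torch.randn_like noise.

import torch
from collections import deque

class NoiseQueue:
    def __init__(self, sampler_noises):
        # consume pre-generated noise in the same order it was supplied
        self.sampler_noises = deque(sampler_noises)

    def randn_like(self, x):
        if self.sampler_noises:
            noise = self.sampler_noises.popleft()
            if noise.shape == x.shape:
                return noise
        # queue exhausted, or the popped tensor has the wrong shape: generate fresh noise
        return torch.randn_like(x)

q = NoiseQueue([torch.zeros(2, 3), torch.zeros(4)])
print(q.randn_like(torch.empty(2, 3)))  # queued zeros tensor, consumed from the left
print(q.randn_like(torch.empty(2, 3)))  # shape mismatch (4,) vs (2, 3): fresh random noise
print(q.randn_like(torch.empty(2, 3)))  # queue empty: fresh random noise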
@@ -367,7 +377,6 @@ class KDiffusionSampler:
         self.extra_params = sampler_extra_params.get(funcname, [])
         self.model_wrap_cfg = CFGDenoiser(self.model_wrap)
         self.sampler_noises = None
-        self.sampler_noise_index = 0
         self.stop_at = None
         self.eta = None
         self.default_eta = 1.0
@@ -400,26 +409,14 @@ class KDiffusionSampler:
     def number_of_needed_noises(self, p):
         return p.steps

-    def randn_like(self, x):
-        noise = self.sampler_noises[self.sampler_noise_index] if self.sampler_noises is not None and self.sampler_noise_index < len(self.sampler_noises) else None
-
-        if noise is not None and x.shape == noise.shape:
-            res = noise
-        else:
-            res = torch.randn_like(x)
-
-        self.sampler_noise_index += 1
-        return res
-
     def initialize(self, p):
         self.model_wrap_cfg.mask = p.mask if hasattr(p, 'mask') else None
         self.model_wrap_cfg.nmask = p.nmask if hasattr(p, 'nmask') else None
         self.model_wrap.step = 0
-        self.sampler_noise_index = 0
         self.eta = p.eta or opts.eta_ancestral

         if self.sampler_noises is not None:
-            k_diffusion.sampling.torch = TorchHijack(self)
+            k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)

         extra_params_kwargs = {}
         for param_name in self.extra_params:
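Why this plausibly fixes the VRAM leak, as the issue title suggests: the old code stored the whole sampler in the module-level k_diffusion.sampling.torch global via TorchHijack(self), which kept the wrapped model reachable after a checkpoint switch; the new code stores only the noise tensors. The sketch below uses hypothetical stand-in classes (FakeModel, FakeSampler, OldHijack, NewHijack) and a hypothetical model_survives helper, not webui code, to make that reference chain observable with weakref.

import weakref

class FakeModel:                      # stands in for the loaded checkpoint
    pass

class FakeSampler:                    # stands in for KDiffusionSampler
    def __init__(self, model):
        self.model_wrap = model       # sampler -> model
        self.sampler_noises = [1, 2]  # plain placeholders for noise tensors

class OldHijack:                      # old-style hijack: keeps the whole sampler
    def __init__(self, sampler):
        self.kdiff_sampler = sampler

class NewHijack:                      # new-style hijack: keeps only the noise list
    def __init__(self, sampler_noises):
        self.sampler_noises = list(sampler_noises)

def model_survives(make_hijack):
    sampler = FakeSampler(FakeModel())
    model_ref = weakref.ref(sampler.model_wrap)
    module_global = make_hijack(sampler)   # mimics k_diffusion.sampling.torch = ...
    del sampler                            # "switch checkpoint": drop the sampler itself
    return model_ref() is not None         # True -> old model still reachable, its VRAM cannot be freed

print(model_survives(lambda s: OldHijack(s)))                  # True: the global keeps the model alive
print(model_survives(lambda s: NewHijack(s.sampler_noises)))   # False: the model becomes collectible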