Commit 44c50973 authored by Ritesh Gangnani's avatar Ritesh Gangnani
Browse files

Use devices.torch_gc() instead of empty_cache()

parent 44db35fb
Loading
Loading
Loading
Loading
+1 −4
Original line number Original line Diff line number Diff line
import gc 

import torch
import torch
from torch.nn.functional import silu
from torch.nn.functional import silu
from types import MethodType
from types import MethodType
@@ -193,8 +191,7 @@ class StableDiffusionModelHijack:
                delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i)
                delattr(m.model.diffusion_model.output_blocks[1][1].transformer_blocks,i)
            delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1')
            delattr(m.model.diffusion_model.output_blocks[4][1].transformer_blocks,'1')
            delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1')
            delattr(m.model.diffusion_model.output_blocks[5][1].transformer_blocks,'1')
            torch.cuda.empty_cache()
            devices.torch_gc()
            gc.collect()


    def hijack(self, m):
    def hijack(self, m):
        conditioner = getattr(m, 'conditioner', None)
        conditioner = getattr(m, 'conditioner', None)
+0 −1
Original line number Original line Diff line number Diff line
@@ -347,7 +347,6 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
    model.is_sd2 = not model.is_sdxl and hasattr(model.cond_stage_model, 'model')
    model.is_sd1 = not model.is_sdxl and not model.is_sd2
    model.is_sd1 = not model.is_sdxl and not model.is_sd2
    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
    model.is_ssd = model.is_sdxl and 'model.diffusion_model.middle_block.1.transformer_blocks.0.attn1.to_q.weight' not in state_dict.keys()
    
    if model.is_sdxl:
    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)
        sd_models_xl.extend_sdxl(model)