Unverified commit f126986b authored by AUTOMATIC1111, committed by GitHub

Merge pull request #4098 from jn-jairo/load-model

Unload sd_model before loading the other one to solve issue #3449
parents 08744040 af758e97
modules/lowvram.py +13 −8
@@ -38,13 +38,18 @@ def setup_for_low_vram(sd_model, use_medvram):
     # see below for register_forward_pre_hook;
     # first_stage_model does not use forward(), it uses encode/decode, so register_forward_pre_hook is
     # useless here, and we just replace those methods
-    def first_stage_model_encode_wrap(self, encoder, x):
-        send_me_to_gpu(self, None)
-        return encoder(x)
+
+    first_stage_model = sd_model.first_stage_model
+    first_stage_model_encode = sd_model.first_stage_model.encode
+    first_stage_model_decode = sd_model.first_stage_model.decode
+
+    def first_stage_model_encode_wrap(x):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_encode(x)
 
-    def first_stage_model_decode_wrap(self, decoder, z):
-        send_me_to_gpu(self, None)
-        return decoder(z)
+    def first_stage_model_decode_wrap(z):
+        send_me_to_gpu(first_stage_model, None)
+        return first_stage_model_decode(z)
 
     # remove three big modules, cond, first_stage, and unet from the model and then
     # send the model to GPU. Then put modules back. the modules will be in CPU.
@@ -56,8 +61,8 @@ def setup_for_low_vram(sd_model, use_medvram):
     # register hooks for those the first two models
     sd_model.cond_stage_model.transformer.register_forward_pre_hook(send_me_to_gpu)
     sd_model.first_stage_model.register_forward_pre_hook(send_me_to_gpu)
-    sd_model.first_stage_model.encode = lambda x, en=sd_model.first_stage_model.encode: first_stage_model_encode_wrap(sd_model.first_stage_model, en, x)
-    sd_model.first_stage_model.decode = lambda z, de=sd_model.first_stage_model.decode: first_stage_model_decode_wrap(sd_model.first_stage_model, de, z)
+    sd_model.first_stage_model.encode = first_stage_model_encode_wrap
+    sd_model.first_stage_model.decode = first_stage_model_decode_wrap
     parents[sd_model.cond_stage_model.transformer] = sd_model.cond_stage_model
 
     if use_medvram:
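Note: the rewritten wrappers capture `first_stage_model` and its original `encode`/`decode` as plain locals, so the patched methods no longer smuggle a reference to the whole `sd_model` through lambda default arguments; once `shared.sd_model` is dropped elsewhere, nothing here keeps the old model reachable. A minimal sketch of the same closure-over-bound-method pattern (names are illustrative, not from the repo):

```python
class FirstStage:
    """Stands in for sd_model.first_stage_model."""
    def encode(self, x):
        return f"encoded {x}"

vae = FirstStage()
original_encode = vae.encode        # bound method captured once

def encode_wrap(x):
    # the real code calls send_me_to_gpu(first_stage_model, None) here
    return original_encode(x)       # call the saved original, not vae.encode

vae.encode = encode_wrap            # instance attribute shadows the method
print(vae.encode("img"))            # -> encoded img
```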
modules/processing.py +3 −0
@@ -597,6 +597,9 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
     if p.scripts is not None:
         p.scripts.postprocess(p, res)
 
+    p.sd_model = None
+    p.sampler = None
+
     return res
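Clearing `p.sd_model` and `p.sampler` once postprocessing is done means a long-lived `StableDiffusionProcessing` object no longer pins the model (and the sampler's references into it) in memory, which is what lets the unload in `load_model` actually free anything. A small sketch of the reference-dropping idea, with hypothetical names:

```python
import gc

class Job:
    def __init__(self, model, sampler):
        self.model = model      # strong refs keep both objects alive
        self.sampler = sampler

def finish(job):
    # once results are produced, drop the refs so the heavyweight
    # objects become collectable even if `job` itself lives on
    job.model = None
    job.sampler = None

job = Job(model=object(), sampler=object())
finish(job)
gc.collect()  # model and sampler are reclaimable now
```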


modules/sd_hijack.py +4 −0
@@ -94,6 +94,10 @@ class StableDiffusionModelHijack:
         if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
             model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped
 
+        self.layers = None
+        self.circular_enabled = False
+        self.clip = None
+
     def apply_circular(self, enable):
         if self.circular_enabled == enable:
             return
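`undo_hijack` now also forgets its own bookkeeping: `layers`, `circular_enabled`, and `clip` all point into the model being unloaded, and resetting them lets that model be garbage collected after the swap. Conceptually (an illustrative sketch, not the class's full API):

```python
class Hijack:
    def __init__(self):
        self.clip = None
        self.layers = None

    def hijack(self, model):
        self.clip = model.cond_stage_model   # remember what was patched
        self.layers = [self.clip]

    def undo_hijack(self, model):
        # restore the patched attributes on `model` (elided), then drop
        # every reference so the unloaded model is not kept reachable
        self.clip = None
        self.layers = None
```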
modules/sd_models.py +13 −1
@@ -1,6 +1,7 @@
 import collections
 import os.path
 import sys
+import gc
 from collections import namedtuple
 import torch
 import re
@@ -220,6 +221,12 @@ def load_model(checkpoint_info=None):
     if checkpoint_info.config != shared.cmd_opts.config:
         print(f"Loading config from: {checkpoint_info.config}")
 
+    if shared.sd_model:
+        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
+        shared.sd_model = None
+        gc.collect()
+        devices.torch_gc()
+
     sd_config = OmegaConf.load(checkpoint_info.config)
 
     if should_hijack_inpainting(checkpoint_info):
@@ -233,6 +240,7 @@ def load_model(checkpoint_info=None):
         checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
 
     do_inpainting_hijack()
+
     sd_model = instantiate_from_config(sd_config.model)
     load_model_weights(sd_model, checkpoint_info)
 
@@ -252,14 +260,18 @@ def load_model(checkpoint_info=None):
     return sd_model
 
 
-def reload_model_weights(sd_model, info=None):
+def reload_model_weights(sd_model=None, info=None):
     from modules import lowvram, devices, sd_hijack
     checkpoint_info = info or select_checkpoint()
 
+    if not sd_model:
+        sd_model = shared.sd_model
+
     if sd_model.sd_model_checkpoint == checkpoint_info.filename:
         return
 
     if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
+        del sd_model
         checkpoints_loaded.clear()
         load_model(checkpoint_info)
         return shared.sd_model
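The core of the fix lives here: `load_model` un-hijacks and drops the current `shared.sd_model`, then `gc.collect()` breaks any remaining reference cycles and `devices.torch_gc()` hands cached CUDA blocks back to the driver, so the old and new checkpoints never have to coexist in memory. `reload_model_weights` gains the `sd_model=None` default (falling back to `shared.sd_model`) plus a `del sd_model` before a full `load_model` for the same reason. A standalone sketch of the unload-then-load sequence, where `free_cuda` approximates what `devices.torch_gc` does:

```python
import gc
import torch

_model = None  # stands in for shared.sd_model

def free_cuda():
    # plausible equivalent of devices.torch_gc()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()   # release cached allocator blocks
        torch.cuda.ipc_collect()   # reclaim CUDA IPC memory

def load(build_new):
    """Unload the current model fully before building its replacement."""
    global _model
    if _model is not None:
        _model = None   # drop the last strong reference
        gc.collect()    # collect cycles that still hold tensors
        free_cuda()     # return freed VRAM before the next allocation
    _model = build_new()
    return _model
```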
webui.py +1 −1
@@ -78,7 +78,7 @@ def initialize():
     modules.scripts.load_scripts()
 
     modules.sd_models.load_model()
-    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights(shared.sd_model)))
+    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
     shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: modules.hypernetworks.hypernetwork.load_hypernetwork(shared.opts.sd_hypernetwork)))
     shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
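With the argument gone, the checkpoint-change callback binds nothing at startup: `reload_model_weights()` resolves `shared.sd_model` at call time, so it always operates on whichever model `load_model` most recently installed rather than on an object captured before the first swap.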