Commit b235022c authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

option to keep multiple models in memory

parent 6f0abbb7
Loading
Loading
Loading
Loading
+3 −0
Original line number Diff line number Diff line
@@ -15,6 +15,9 @@ def send_everything_to_cpu():


def setup_for_low_vram(sd_model, use_medvram):
    if getattr(sd_model, 'lowvram', False):
        return

    sd_model.lowvram = True

    parents = {}
+4 −2
Original line number Diff line number Diff line
@@ -30,8 +30,10 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention

# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
+1 −4
Original line number Diff line number Diff line
@@ -91,7 +91,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F
    return x_prev, pred_x0, e_t


def do_inpainting_hijack():
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings

ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+112 −24
Original line number Diff line number Diff line
@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd

@@ -423,6 +422,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

@@ -437,6 +437,7 @@ class SdModelData:

                try:
                    load_model()

                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
@@ -448,11 +449,24 @@ class SdModelData:
    def set_sd_model(self, v):
        self.sd_model = v

        try:
            self.loaded_sd_models.remove(v)
        except ValueError:
            pass

        if v is not None:
            self.loaded_sd_models.insert(0, v)


model_data = SdModelData()


def get_empty_cond(sd_model):
    from modules import extra_networks, processing

    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    if hasattr(sd_model, 'conditioner'):
        d = sd_model.get_learned_conditioning([""])
        return d['crossattn']
@@ -460,19 +474,43 @@ def get_empty_cond(sd_model):
        return sd_model.cond_stage_model([""])


def send_model_to_cpu(m):
    from modules import lowvram

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        m.to(devices.cpu)

    devices.torch_gc()


def send_model_to_device(m):
    from modules import lowvram

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
    else:
        m.to(shared.device)


def send_model_to_trash(m):
    m.to(device="meta")
    devices.torch_gc()


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import lowvram, sd_hijack
    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    timer = Timer()

    if model_data.sd_model:
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        gc.collect()
        devices.torch_gc()

    do_inpainting_hijack()

    timer = Timer()
    timer.record("unload existing model")

    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
@@ -512,12 +550,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)

    send_model_to_device(sd_model)
    timer.record("move model to device")

    sd_hijack.model_hijack.hijack(sd_model)
@@ -525,7 +560,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    timer.record("hijack")

    sd_model.eval()
    model_data.sd_model = sd_model
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
@@ -546,10 +581,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    return sd_model


def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
    If not, returns the model that can be used to load weights from checkpoint_info's file.
    If no such model exists, returns None.
    Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """

    already_loaded = None
    for i in reversed(range(len(model_data.loaded_sd_models))):
        loaded_model = model_data.loaded_sd_models[i]
        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = loaded_model
            continue

        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            model_data.loaded_sd_models.pop()
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

        if shared.opts.sd_checkpoints_keep_in_cpu:
            send_model_to_cpu(sd_model)
            timer.record("send model to cpu")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded)
        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        return model_data.sd_model
    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model
    elif len(model_data.loaded_sd_models) > 0:
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else:
        return None


def reload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    from modules import devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    timer = Timer()

    if not sd_model:
        sd_model = model_data.sd_model

@@ -558,19 +644,17 @@ def reload_model_weights(sd_model=None, info=None):
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return

        sd_unet.apply_unet("None")
            return sd_model

        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
            lowvram.send_everything_to_cpu()
        else:
            sd_model.to(devices.cpu)
    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        return sd_model

    if sd_model is not None:
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)

    timer = Timer()

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -578,7 +662,9 @@ def reload_model_weights(sd_model=None, info=None):
    timer.record("find config")

    if sd_model is None or checkpoint_config != sd_model.used_config:
        del sd_model
        if sd_model is not None:
            send_model_to_trash(sd_model)

        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model

@@ -601,6 +687,8 @@ def reload_model_weights(sd_model=None, info=None):

    print(f"Weights loaded in {timer.summary()}.")

    model_data.set_sd_model(sd_model)

    return sd_model


+4 −4
Original line number Diff line number Diff line
@@ -98,10 +98,10 @@ def extend_sdxl(model):
    model.conditioner.wrapped = torch.nn.Module()


sgm.modules.attention.print = lambda *args: None
sgm.modules.diffusionmodules.model.print = lambda *args: None
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
sgm.modules.encoders.modules.print = lambda *args: None
sgm.modules.attention.print = shared.ldm_print
sgm.modules.diffusionmodules.model.print = shared.ldm_print
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
sgm.modules.encoders.modules.print = shared.ldm_print

# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
Loading