Unverified Commit c613416a authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #12227 from AUTOMATIC1111/multiple_loaded_models

option to keep multiple models in memory
parents 0ae2767a 22ecb78b
Loading
Loading
Loading
Loading
+3 −0
Original line number Original line Diff line number Diff line
@@ -15,6 +15,9 @@ def send_everything_to_cpu():




def setup_for_low_vram(sd_model, use_medvram):
def setup_for_low_vram(sd_model, use_medvram):
    if getattr(sd_model, 'lowvram', False):
        return

    sd_model.lowvram = True
    sd_model.lowvram = True


    parents = {}
    parents = {}
+7 −3
Original line number Original line Diff line number Diff line
@@ -5,7 +5,7 @@ from types import MethodType
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr, sd_hijack_inpainting


import ldm.modules.attention
import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.modules.diffusionmodules.model
@@ -29,8 +29,12 @@ ldm.modules.attention.MemoryEfficientCrossAttention = ldm.modules.attention.Cros
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention
ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"] = ldm.modules.attention.CrossAttention


# silence new console spam from SD2
# silence new console spam from SD2
ldm.modules.attention.print = lambda *args: None
ldm.modules.attention.print = shared.ldm_print
ldm.modules.diffusionmodules.model.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = shared.ldm_print
ldm.util.print = shared.ldm_print
ldm.models.diffusion.ddpm.print = shared.ldm_print

sd_hijack_inpainting.do_inpainting_hijack()


optimizers = []
optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None
current_optimizer: sd_hijack_optimizations.SdOptimization = None
+0 −2
Original line number Original line Diff line number Diff line
@@ -92,6 +92,4 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F




def do_inpainting_hijack():
def do_inpainting_hijack():
    # p_sample_plms is needed because PLMS can't work with dicts as conditionings

    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
+110 −25
Original line number Original line Diff line number Diff line
@@ -15,7 +15,6 @@ import ldm.modules.midas as midas
from ldm.util import instantiate_from_config
from ldm.util import instantiate_from_config


from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
from modules.timer import Timer
import tomesd
import tomesd


@@ -434,6 +433,7 @@ sdxl_refiner_clip_weight = 'conditioner.embedders.0.model.ln_final.weight'
class SdModelData:
class SdModelData:
    def __init__(self):
    def __init__(self):
        self.sd_model = None
        self.sd_model = None
        self.loaded_sd_models = []
        self.was_loaded_at_least_once = False
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()
        self.lock = threading.Lock()


@@ -448,6 +448,7 @@ class SdModelData:


                try:
                try:
                    load_model()
                    load_model()

                except Exception as e:
                except Exception as e:
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    errors.display(e, "loading stable diffusion model", full_traceback=True)
                    print("", file=sys.stderr)
                    print("", file=sys.stderr)
@@ -459,11 +460,24 @@ class SdModelData:
    def set_sd_model(self, v):
    def set_sd_model(self, v):
        self.sd_model = v
        self.sd_model = v


        try:
            self.loaded_sd_models.remove(v)
        except ValueError:
            pass

        if v is not None:
            self.loaded_sd_models.insert(0, v)



model_data = SdModelData()
model_data = SdModelData()




def get_empty_cond(sd_model):
def get_empty_cond(sd_model):
    from modules import extra_networks, processing

    p = processing.StableDiffusionProcessingTxt2Img()
    extra_networks.activate(p, {})

    if hasattr(sd_model, 'conditioner'):
    if hasattr(sd_model, 'conditioner'):
        d = sd_model.get_learned_conditioning([""])
        d = sd_model.get_learned_conditioning([""])
        return d['crossattn']
        return d['crossattn']
@@ -471,19 +485,43 @@ def get_empty_cond(sd_model):
        return sd_model.cond_stage_model([""])
        return sd_model.cond_stage_model([""])




def send_model_to_cpu(m):
    from modules import lowvram

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
        m.to(devices.cpu)

    devices.torch_gc()


def send_model_to_device(m):
    from modules import lowvram

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
    else:
        m.to(shared.device)


def send_model_to_trash(m):
    m.to(device="meta")
    devices.torch_gc()


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import lowvram, sd_hijack
    from modules import sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
    checkpoint_info = checkpoint_info or select_checkpoint()


    timer = Timer()

    if model_data.sd_model:
    if model_data.sd_model:
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        send_model_to_trash(model_data.sd_model)
        model_data.sd_model = None
        model_data.sd_model = None
        gc.collect()
        devices.torch_gc()
        devices.torch_gc()


    do_inpainting_hijack()
    timer.record("unload existing model")

    timer = Timer()


    if already_loaded_state_dict is not None:
    if already_loaded_state_dict is not None:
        state_dict = already_loaded_state_dict
        state_dict = already_loaded_state_dict
@@ -523,12 +561,9 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):


    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, devices.cpu):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")


    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
    send_model_to_device(sd_model)
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
        sd_model.to(shared.device)

    timer.record("move model to device")
    timer.record("move model to device")


    sd_hijack.model_hijack.hijack(sd_model)
    sd_hijack.model_hijack.hijack(sd_model)
@@ -536,7 +571,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    timer.record("hijack")
    timer.record("hijack")


    sd_model.eval()
    sd_model.eval()
    model_data.sd_model = sd_model
    model_data.set_sd_model(sd_model)
    model_data.was_loaded_at_least_once = True
    model_data.was_loaded_at_least_once = True


    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model
@@ -557,10 +592,61 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    return sd_model
    return sd_model




def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
    """
    Checks if the desired checkpoint from checkpoint_info is not already loaded in model_data.loaded_sd_models.
    If it is loaded, returns that (moving it to GPU if necessary, and moving the currently loadded model to CPU if necessary).
    If not, returns the model that can be used to load weights from checkpoint_info's file.
    If no such model exists, returns None.
    Additionaly deletes loaded models that are over the limit set in settings (sd_checkpoints_limit).
    """

    already_loaded = None
    for i in reversed(range(len(model_data.loaded_sd_models))):
        loaded_model = model_data.loaded_sd_models[i]
        if loaded_model.sd_checkpoint_info.filename == checkpoint_info.filename:
            already_loaded = loaded_model
            continue

        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            model_data.loaded_sd_models.pop()
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

        if shared.opts.sd_checkpoints_keep_in_cpu:
            send_model_to_cpu(sd_model)
            timer.record("send model to cpu")

    if already_loaded is not None:
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded)
        print(f"Using already loaded model {already_loaded.sd_checkpoint_info.title}: done in {timer.summary()}")
        return model_data.sd_model
    elif shared.opts.sd_checkpoints_limit > 1 and len(model_data.loaded_sd_models) < shared.opts.sd_checkpoints_limit:
        print(f"Loading model {checkpoint_info.title} ({len(model_data.loaded_sd_models) + 1} out of {shared.opts.sd_checkpoints_limit})")

        model_data.sd_model = None
        load_model(checkpoint_info)
        return model_data.sd_model
    elif len(model_data.loaded_sd_models) > 0:
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else:
        return None


def reload_model_weights(sd_model=None, info=None):
def reload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    from modules import devices, sd_hijack
    checkpoint_info = info or select_checkpoint()
    checkpoint_info = info or select_checkpoint()


    timer = Timer()

    if not sd_model:
    if not sd_model:
        sd_model = model_data.sd_model
        sd_model = model_data.sd_model


@@ -569,19 +655,17 @@ def reload_model_weights(sd_model=None, info=None):
    else:
    else:
        current_checkpoint_info = sd_model.sd_checkpoint_info
        current_checkpoint_info = sd_model.sd_checkpoint_info
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return
            return sd_model

        sd_unet.apply_unet("None")


        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
    sd_model = reuse_model_from_already_loaded(sd_model, checkpoint_info, timer)
            lowvram.send_everything_to_cpu()
    if sd_model is not None and sd_model.sd_checkpoint_info.filename == checkpoint_info.filename:
        else:
        return sd_model
            sd_model.to(devices.cpu)


    if sd_model is not None:
        sd_unet.apply_unet("None")
        send_model_to_cpu(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)
        sd_hijack.model_hijack.undo_hijack(sd_model)


    timer = Timer()

    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)
    state_dict = get_checkpoint_state_dict(checkpoint_info, timer)


    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
@@ -590,9 +674,8 @@ def reload_model_weights(sd_model=None, info=None):


    if sd_model is None or checkpoint_config != sd_model.used_config:
    if sd_model is None or checkpoint_config != sd_model.used_config:
        if sd_model is not None:
        if sd_model is not None:
            sd_model.to(device="meta")
            send_model_to_trash(sd_model)


        devices.torch_gc()
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return model_data.sd_model
        return model_data.sd_model


@@ -615,6 +698,8 @@ def reload_model_weights(sd_model=None, info=None):


    print(f"Weights loaded in {timer.summary()}.")
    print(f"Weights loaded in {timer.summary()}.")


    model_data.set_sd_model(sd_model)

    return sd_model
    return sd_model




+4 −4
Original line number Original line Diff line number Diff line
@@ -98,10 +98,10 @@ def extend_sdxl(model):
    model.conditioner.wrapped = torch.nn.Module()
    model.conditioner.wrapped = torch.nn.Module()




sgm.modules.attention.print = lambda *args: None
sgm.modules.attention.print = shared.ldm_print
sgm.modules.diffusionmodules.model.print = lambda *args: None
sgm.modules.diffusionmodules.model.print = shared.ldm_print
sgm.modules.diffusionmodules.openaimodel.print = lambda *args: None
sgm.modules.diffusionmodules.openaimodel.print = shared.ldm_print
sgm.modules.encoders.modules.print = lambda *args: None
sgm.modules.encoders.modules.print = shared.ldm_print


# this gets the code to load the vanilla attention that we override
# this gets the code to load the vanilla attention that we override
sgm.modules.attention.SDP_IS_AVAILABLE = True
sgm.modules.attention.SDP_IS_AVAILABLE = True
Loading