Commit abc1e79a authored by Muhammad Rizqi Nur's avatar Muhammad Rizqi Nur
Browse files

Fix base VAE caching was done after loading VAE, also add safeguard

parent 8ab49274
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
    model.sd_model_checkpoint = checkpoint_file
    model.sd_model_checkpoint = checkpoint_file
    model.sd_checkpoint_info = checkpoint_info
    model.sd_checkpoint_info = checkpoint_info


    sd_vae.clear_loaded_vae()
    sd_vae.load_vae(model, vae_file)
    sd_vae.load_vae(model, vae_file)




+8 −11
Original line number Original line Diff line number Diff line
@@ -15,7 +15,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}




default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_dict = {"auto": "auto", "None": None, None: None}
default_vae_list = ["auto", "None"]
default_vae_list = ["auto", "None"]




@@ -39,6 +39,7 @@ def get_base_vae(model):
def store_base_vae(model):
def store_base_vae(model):
    global base_vae, checkpoint_info
    global base_vae, checkpoint_info
    if checkpoint_info != model.sd_checkpoint_info:
    if checkpoint_info != model.sd_checkpoint_info:
        assert not loaded_vae_file, "Trying to store non-base VAE!"
        base_vae = model.first_stage_model.state_dict().copy()
        base_vae = model.first_stage_model.state_dict().copy()
        checkpoint_info = model.sd_checkpoint_info
        checkpoint_info = model.sd_checkpoint_info


@@ -50,9 +51,11 @@ def delete_base_vae():




def restore_base_vae(model):
def restore_base_vae(model):
    global loaded_vae_file
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        print("Restoring base VAE")
        print("Restoring base VAE")
        load_vae_dict(model, base_vae)
        load_vae_dict(model, base_vae)
        loaded_vae_file = None
    delete_base_vae()
    delete_base_vae()




@@ -140,10 +143,10 @@ def load_vae(model, vae_file=None):


    if vae_file:
    if vae_file:
        print(f"Loading VAE weights from: {vae_file}")
        print(f"Loading VAE weights from: {vae_file}")
        store_base_vae(model)
        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
        load_vae_dict(model, vae_dict_1)
        load_vae_dict(model, vae_dict_1)
        store_base_vae(model)


        # If vae used is not in dict, update it
        # If vae used is not in dict, update it
        # It will be removed on refresh though
        # It will be removed on refresh though
@@ -157,15 +160,6 @@ def load_vae(model, vae_file=None):


    loaded_vae_file = vae_file
    loaded_vae_file = vae_file


    """
    # Save current VAE to VAE settings, maybe? will it work?
    if save_settings:
        if vae_file is None:
            vae_opt = "None"

        # shared.opts.sd_vae = vae_opt
    """

    first_load = False
    first_load = False




@@ -174,6 +168,9 @@ def load_vae_dict(model, vae_dict_1):
    model.first_stage_model.load_state_dict(vae_dict_1)
    model.first_stage_model.load_state_dict(vae_dict_1)
    model.first_stage_model.to(devices.dtype_vae)
    model.first_stage_model.to(devices.dtype_vae)


def clear_loaded_vae():
    global loaded_vae_file
    loaded_vae_file = None


def reload_vae_weights(sd_model=None, vae_file="auto"):
def reload_vae_weights(sd_model=None, vae_file="auto"):
    from modules import lowvram, devices, sd_hijack
    from modules import lowvram, devices, sd_hijack