Commit abc1e79a authored by Muhammad Rizqi Nur's avatar Muhammad Rizqi Nur
Browse files

Fix base VAE caching was done after loading VAE, also add safeguard

parent 8ab49274
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -220,6 +220,7 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
    model.sd_model_checkpoint = checkpoint_file
    model.sd_checkpoint_info = checkpoint_info

    sd_vae.clear_loaded_vae()
    sd_vae.load_vae(model, vae_file)


+8 −11
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ vae_path = os.path.abspath(os.path.join(models_path, vae_dir))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}


default_vae_dict = {"auto": "auto", "None": "None"}
default_vae_dict = {"auto": "auto", "None": None, None: None}
default_vae_list = ["auto", "None"]


@@ -39,6 +39,7 @@ def get_base_vae(model):
def store_base_vae(model):
    global base_vae, checkpoint_info
    if checkpoint_info != model.sd_checkpoint_info:
        assert not loaded_vae_file, "Trying to store non-base VAE!"
        base_vae = model.first_stage_model.state_dict().copy()
        checkpoint_info = model.sd_checkpoint_info

@@ -50,9 +51,11 @@ def delete_base_vae():


def restore_base_vae(model):
    global loaded_vae_file
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info:
        print("Restoring base VAE")
        load_vae_dict(model, base_vae)
        loaded_vae_file = None
    delete_base_vae()


@@ -140,10 +143,10 @@ def load_vae(model, vae_file=None):

    if vae_file:
        print(f"Loading VAE weights from: {vae_file}")
        store_base_vae(model)
        vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
        vae_dict_1 = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
        load_vae_dict(model, vae_dict_1)
        store_base_vae(model)

        # If vae used is not in dict, update it
        # It will be removed on refresh though
@@ -157,15 +160,6 @@ def load_vae(model, vae_file=None):

    loaded_vae_file = vae_file

    """
    # Save current VAE to VAE settings, maybe? will it work?
    if save_settings:
        if vae_file is None:
            vae_opt = "None"

        # shared.opts.sd_vae = vae_opt
    """

    first_load = False


@@ -174,6 +168,9 @@ def load_vae_dict(model, vae_dict_1):
    model.first_stage_model.load_state_dict(vae_dict_1)
    model.first_stage_model.to(devices.dtype_vae)

def clear_loaded_vae():
    global loaded_vae_file
    loaded_vae_file = None

def reload_vae_weights(sd_model=None, vae_file="auto"):
    from modules import lowvram, devices, sd_hijack