Commit ac0aa2b1 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

loading SD VAE, see PR #3303

parent 3d898044
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -155,6 +155,9 @@ def get_state_dict_from_checkpoint(pl_sd):
    return pl_sd


vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}


def load_model_weights(model, checkpoint_info):
    checkpoint_file = checkpoint_info.filename
    sd_model_hash = checkpoint_info.hash
@@ -186,7 +189,7 @@ def load_model_weights(model, checkpoint_info):
        if os.path.exists(vae_file):
            print(f"Loading VAE weights from: {vae_file}")
            vae_ckpt = torch.load(vae_file, map_location=shared.weight_load_location)
            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss"}
            vae_dict = {k: v for k, v in vae_ckpt["state_dict"].items() if k[0:4] != "loss" and k not in vae_ignore_keys}
            model.first_stage_model.load_state_dict(vae_dict)

        model.first_stage_model.to(devices.dtype_vae)