Commit 10aca1ca authored by AUTOMATIC's avatar AUTOMATIC
Browse files

more careful loading of model weights (eliminates some issues with checkpoints...

more careful loading of model weights (eliminates some issues with checkpoints that have weird cond_stage_model layer names)
parent c1093b80
Loading
Loading
Loading
Loading
+25 −3
Original line number Diff line number Diff line
@@ -122,11 +122,33 @@ def select_checkpoint():
    return checkpoint_info


chckpoint_dict_replacements = {
    'cond_stage_model.transformer.embeddings.': 'cond_stage_model.transformer.text_model.embeddings.',
    'cond_stage_model.transformer.encoder.': 'cond_stage_model.transformer.text_model.encoder.',
    'cond_stage_model.transformer.final_layer_norm.': 'cond_stage_model.transformer.text_model.final_layer_norm.',
}


def transform_checkpoint_dict_key(k):
    for text, replacement in chckpoint_dict_replacements.items():
        if k.startswith(text):
            k = replacement + k[len(text):]

    return k


def get_state_dict_from_checkpoint(pl_sd):
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]
        pl_sd = pl_sd["state_dict"]

    sd = {}
    for k, v in pl_sd.items():
        new_key = transform_checkpoint_dict_key(k)

        if new_key is not None:
            sd[new_key] = v

    return pl_sd
    return sd


def load_model_weights(model, checkpoint_info):
@@ -141,7 +163,7 @@ def load_model_weights(model, checkpoint_info):
            print(f"Global Step: {pl_sd['global_step']}")

        sd = get_state_dict_from_checkpoint(pl_sd)
        model.load_state_dict(sd, strict=False)
        missing, extra = model.load_state_dict(sd, strict=False)

        if shared.cmd_opts.opt_channelslast:
            model.to(memory_format=torch.channels_last)