Commit c77c89cc authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make main model loading and model merger use the same code

parent 050a6a79
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -170,8 +170,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, interp_method, int
    print(f"Loading {secondary_model_info.filename}...")
    secondary_model = torch.load(secondary_model_info.filename, map_location='cpu')

    theta_0 = primary_model['state_dict']
    theta_1 = secondary_model['state_dict']
    theta_0 = sd_models.get_state_dict_from_checkpoint(primary_model)
    theta_1 = sd_models.get_state_dict_from_checkpoint(secondary_model)

    theta_funcs = {
        "Weighted Sum": weighted_sum,
+9 −5
Original line number Diff line number Diff line
@@ -122,6 +122,13 @@ def select_checkpoint():
    return checkpoint_info


def get_state_dict_from_checkpoint(pl_sd):
    if "state_dict" in pl_sd:
        return pl_sd["state_dict"]

    return pl_sd


def load_model_weights(model, checkpoint_info):
    checkpoint_file = checkpoint_info.filename
    sd_model_hash = checkpoint_info.hash
@@ -132,10 +139,7 @@ def load_model_weights(model, checkpoint_info):
    if "global_step" in pl_sd:
        print(f"Global Step: {pl_sd['global_step']}")

    if "state_dict" in pl_sd:
        sd = pl_sd["state_dict"]
    else:
        sd = pl_sd
    sd = get_state_dict_from_checkpoint(pl_sd)

    model.load_state_dict(sd, strict=False)