Commit 44c46f0e authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make it possible to merge inpainting model with non-inpainting one

parent 8504db51
Loading
Loading
Loading
Loading
+25 −2
Original line number Diff line number Diff line
@@ -247,6 +247,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam
    primary_model_info = sd_models.checkpoints_list[primary_model_name]
    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]
    teritary_model_info = sd_models.checkpoints_list.get(teritary_model_name, None)
    result_is_inpainting_model = False

    print(f"Loading {primary_model_info.filename}...")
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -280,8 +281,22 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam

    for key in tqdm.tqdm(theta_0.keys()):
        if 'model' in key and key in theta_1:
            a = theta_0[key]
            b = theta_1[key]

            theta_0[key] = theta_func2(theta_0[key], theta_1[key], multiplier)
            # this enables merging an inpainting model (A) with another one (B);
            # where normal model would have 4 channels, for latenst space, inpainting model would
            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
            if a.shape != b.shape and a.shape[0:1] + a.shape[2:] == b.shape[0:1] + b.shape[2:]:
                if a.shape[1] == 4 and b.shape[1] == 9:
                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")

                assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"

                theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                result_is_inpainting_model = True
            else:
                theta_0[key] = theta_func2(a, b, multiplier)

            if save_as_half:
                theta_0[key] = theta_0[key].half()
@@ -295,8 +310,16 @@ def run_modelmerger(primary_model_name, secondary_model_name, teritary_model_nam

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    filename = primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + interp_method.replace(" ", "_") + '-merged.' + checkpoint_format
    filename = \
        primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
        secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
        interp_method.replace(" ", "_") + \
        '-merged.' +  \
        ("inpainting." if result_is_inpainting_model else "") + \
        checkpoint_format

    filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)

    output_modelname = os.path.join(ckpt_dir, filename)

    print(f"Saving to {output_modelname}...")