Commit f4ec411f authored by ULTRANOX\Chris's avatar ULTRANOX\Chris
Browse files

Allow checkpoint merger to merge pix2pix models in the same way that it...

Allow checkpoint merger to merge pix2pix models in the same way that it currently supports inpainting models.
parent 6cff4401
Loading
Loading
Loading
Loading
+11 −5
Original line number Diff line number Diff line
@@ -132,6 +132,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None

    result_is_inpainting_model = False
    result_is_pix2pix_model = False

    if theta_func2:
        shared.state.textinfo = f"Loading B"
@@ -186,8 +187,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
                if a.shape[1] == 4 and b.shape[1] == 9:
                    raise RuntimeError("When merging inpainting model with a normal one, A must be the inpainting model.")

                if a.shape[1] == 8 and b.shape[1] == 4:#If we have an InstructPix2Pix model...
                    print("Detected possible merge of instruct model with non-instruct model.")
                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)#Merge only the vectors the models have in common.  Otherwise we get an error due to dimension mismatch.
                    result_is_pix2pix_model = True
                else:
                    assert a.shape[1] == 9 and b.shape[1] == 4, f"Bad dimensions for merged layer {key}: A={a.shape}, B={b.shape}"

                    theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                    result_is_inpainting_model = True
            else:
@@ -226,6 +231,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_

    filename = filename_generator() if custom_name == '' else custom_name
    filename += ".inpainting" if result_is_inpainting_model else ""
    filename += ".pix2pix" if result_is_pix2pix_model else ""
    filename += "." + checkpoint_format

    output_modelname = os.path.join(ckpt_dir, filename)