Commit fbaf6e4f authored by space-nuko's avatar space-nuko
Browse files

Namespace metadata fields

parent 7c016dd6
Loading
Loading
Loading
Loading
+6 −6
Original line number Diff line number Diff line
@@ -242,7 +242,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    metadata = {"format": "pt", "models": {}, "merge_recipe": None}
    metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}

    if save_metadata:
        merge_recipe = {
@@ -260,17 +260,17 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
            "is_inpainting": result_is_inpainting_model,
            "is_instruct_pix2pix": result_is_instruct_pix2pix_model
        }
        metadata["merge_recipe"] = json.dumps(merge_recipe)
        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

        def add_model_metadata(checkpoint_info):
            checkpoint_info.calculate_shorthash()
            metadata["models"][checkpoint_info.sha256] = {
            metadata["sd_merge_models"][checkpoint_info.sha256] = {
                "name": checkpoint_info.name,
                "legacy_hash": checkpoint_info.hash,
                "merge_recipe": checkpoint_info.metadata.get("merge_recipe", None)
                "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
            }

            metadata["models"].update(checkpoint_info.metadata.get("models", {}))
            metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))

        add_model_metadata(primary_model_info)
        if secondary_model_info:
@@ -278,7 +278,7 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
        if tertiary_model_info:
            add_model_metadata(tertiary_model_info)

        metadata["models"] = json.dumps(metadata["models"])
        metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":