Unverified Commit 89877643 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #9312 from space-nuko/save-merge-recipe

Embed model merge metadata in .safetensors file
parents 31dbec6b fbaf6e4f
Loading
Loading
Loading
Loading
+44 −2
Original line number Diff line number Diff line
import os
import re
import shutil
import json


import torch
@@ -71,7 +72,7 @@ def to_half(tensor, enable):
    return tensor


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
    shared.state.begin()
    shared.state.job = 'model-merge'

@@ -241,13 +242,54 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    metadata = {"format": "pt", "sd_merge_models": {}, "sd_merge_recipe": None}

    if save_metadata:
        merge_recipe = {
            "type": "webui", # indicate this model was merged with webui's built-in merger
            "primary_model_hash": primary_model_info.sha256,
            "secondary_model_hash": secondary_model_info.sha256 if secondary_model_info else None,
            "tertiary_model_hash": tertiary_model_info.sha256 if tertiary_model_info else None,
            "interp_method": interp_method,
            "multiplier": multiplier,
            "save_as_half": save_as_half,
            "custom_name": custom_name,
            "config_source": config_source,
            "bake_in_vae": bake_in_vae,
            "discard_weights": discard_weights,
            "is_inpainting": result_is_inpainting_model,
            "is_instruct_pix2pix": result_is_instruct_pix2pix_model
        }
        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

        def add_model_metadata(checkpoint_info):
            checkpoint_info.calculate_shorthash()
            metadata["sd_merge_models"][checkpoint_info.sha256] = {
                "name": checkpoint_info.name,
                "legacy_hash": checkpoint_info.hash,
                "sd_merge_recipe": checkpoint_info.metadata.get("sd_merge_recipe", None)
            }

            metadata["sd_merge_models"].update(checkpoint_info.metadata.get("sd_merge_models", {}))

        add_model_metadata(primary_model_info)
        if secondary_model_info:
            add_model_metadata(secondary_model_info)
        if tertiary_model_info:
            add_model_metadata(tertiary_model_info)

        metadata["sd_merge_models"] = json.dumps(metadata["sd_merge_models"])

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata={"format": "pt"})
        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
    else:
        torch.save(theta_0, output_modelname)

    sd_models.list_models()
    created_model = next((ckpt for ckpt in sd_models.checkpoints_list.values() if ckpt.name == filename), None)
    if created_model:
        created_model.calculate_shorthash()

    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)

+10 −1
Original line number Diff line number Diff line
@@ -52,6 +52,15 @@ class CheckpointInfo:

        self.ids = [self.hash, self.model_name, self.title, name, f'{name} [{self.hash}]'] + ([self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]'] if self.shorthash else [])

        self.metadata = {}

        _, ext = os.path.splitext(self.filename)
        if ext.lower() == ".safetensors":
            try:
                self.metadata = read_metadata_from_safetensors(filename)
            except Exception as e:
                errors.display(e, f"reading checkpoint metadata: {filename}")

    def register(self):
        checkpoints_list[self.title] = self
        for id in self.ids:
+3 −1
Original line number Diff line number Diff line
@@ -1019,8 +1019,9 @@ def create_ui():
                interp_method.change(fn=update_interp_description, inputs=[interp_method], outputs=[interp_description])

                with FormRow():
                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="ckpt", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
                    checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
                    save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
                    save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")

                with FormRow():
                    with gr.Column():
@@ -1658,6 +1659,7 @@ def create_ui():
                config_source,
                bake_in_vae,
                discard_weights,
                save_metadata,
            ],
            outputs=[
                primary_model_name,