Commit 07be13ca authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

add metadata to checkpoint merger

parent 6d3a0c95
Loading
Loading
Loading
Loading
+33 −6
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ import json
import torch
import tqdm

from modules import shared, images, sd_models, sd_vae, sd_models_config
from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
    return tensor


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
    metadata = {}

    for checkpoint_name in [primary_model_name, secondary_model_name, tertiary_model_name]:
        checkpoint_info = sd_models.checkpoints_list.get(checkpoint_name, None)
        if checkpoint_info is None:
            continue

        metadata.update(checkpoint_info.metadata)

    return json.dumps(metadata, indent=4, ensure_ascii=False)


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
    shared.state.begin(job="model-merge")

    def fail(message):
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    metadata = None
    metadata = {}

    if save_metadata and copy_metadata_fields:
        if primary_model_info:
            metadata.update(primary_model_info.metadata)
        if secondary_model_info:
            metadata.update(secondary_model_info.metadata)
        if tertiary_model_info:
            metadata.update(tertiary_model_info.metadata)

    if save_metadata:
        metadata = {"format": "pt"}
        try:
            metadata.update(json.loads(metadata_json))
        except Exception as e:
            errors.display(e, "readin metadata from json")

        metadata["format"] = "pt"

    if save_metadata and add_merge_recipe:
        merge_recipe = {
            "type": "webui", # indicate this model was merged with webui's built-in merger
            "primary_model_hash": primary_model_info.sha256,
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
            "is_inpainting": result_is_inpainting_model,
            "is_instruct_pix2pix": result_is_instruct_pix2pix_model
        }
        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

        sd_merge_models = {}

@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
        if tertiary_model_info:
            add_model_metadata(tertiary_model_info)

        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
        metadata["sd_merge_models"] = json.dumps(sd_merge_models)

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
    else:
        torch.save(theta_0, output_modelname)

+1 −1
Original line number Diff line number Diff line
@@ -85,7 +85,7 @@ class CheckpointInfo:
        if self.shorthash not in self.ids:
            self.ids += [self.shorthash, self.sha256, f'{self.name} [{self.shorthash}]']

        checkpoints_list.pop(self.title)
        checkpoints_list.pop(self.title, None)
        self.title = f'{self.name} [{self.shorthash}]'
        self.register()

+18 −2
Original line number Diff line number Diff line
@@ -51,7 +51,6 @@ class UiCheckpointMerger:
                    with FormRow():
                        self.checkpoint_format = gr.Radio(choices=["ckpt", "safetensors"], value="safetensors", label="Checkpoint format", elem_id="modelmerger_checkpoint_format")
                        self.save_as_half = gr.Checkbox(value=False, label="Save as float16", elem_id="modelmerger_save_as_half")
                        self.save_metadata = gr.Checkbox(value=True, label="Save metadata (.safetensors only)", elem_id="modelmerger_save_metadata")

                    with FormRow():
                        with gr.Column():
@@ -65,16 +64,30 @@ class UiCheckpointMerger:
                    with FormRow():
                        self.discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")

                    with gr.Row():
                    with gr.Accordion("Metadata", open=False) as metadata_editor:
                        with FormRow():
                            self.save_metadata = gr.Checkbox(value=True, label="Save metadata", elem_id="modelmerger_save_metadata")
                            self.add_merge_recipe = gr.Checkbox(value=True, label="Add merge recipe metadata", elem_id="modelmerger_add_recipe")
                            self.copy_metadata_fields = gr.Checkbox(value=True, label="Copy metadata from merged models", elem_id="modelmerger_copy_metadata")

                        self.metadata_json = gr.TextArea('{}', label="Metadata in JSON format")
                        self.read_metadata = gr.Button("Read metadata from selected checkpoints")

                    with FormRow():
                        self.modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')

                with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
                    with gr.Group(elem_id="modelmerger_results_panel"):
                        self.modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)

        self.metadata_editor = metadata_editor
        self.blocks = modelmerger_interface

    def setup_ui(self, dummy_component, sd_model_checkpoint_component):
        self.checkpoint_format.change(lambda fmt: gr.update(visible=fmt == 'safetensors'), inputs=[self.checkpoint_format], outputs=[self.metadata_editor], show_progress=False)

        self.read_metadata.click(extras.read_metadata, inputs=[self.primary_model_name, self.secondary_model_name, self.tertiary_model_name], outputs=[self.metadata_json])

        self.modelmerger_merge.click(fn=lambda: '', inputs=[], outputs=[self.modelmerger_result])
        self.modelmerger_merge.click(
            fn=call_queue.wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
@@ -93,6 +106,9 @@ class UiCheckpointMerger:
                self.bake_in_vae,
                self.discard_weights,
                self.save_metadata,
                self.add_merge_recipe,
                self.copy_metadata_fields,
                self.metadata_json,
            ],
            outputs=[
                self.primary_model_name,