Commit 112416d0 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add option to discard weights in checkpoint merger UI

parent 0792fae0
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
from __future__ import annotations
import math
import os
import re
import sys
import traceback
import shutil
@@ -285,7 +286,7 @@ def to_half(tensor, enable):
    return tensor


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights):
    shared.state.begin()
    shared.state.job = 'model-merge'

@@ -430,6 +431,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
        for key in theta_0.keys():
            theta_0[key] = to_half(theta_0[key], save_as_half)

    if discard_weights:
        regex = re.compile(discard_weights)
        for key in list(theta_0):
            if re.search(regex, key):
                theta_0.pop(key, None)

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    filename = filename_generator() if custom_name == '' else custom_name
+4 −0
Original line number Diff line number Diff line
@@ -1248,6 +1248,9 @@ def create_ui():
                            bake_in_vae = gr.Dropdown(choices=["None"] + list(sd_vae.vae_dict), value="None", label="Bake in VAE", elem_id="modelmerger_bake_in_vae")
                            create_refresh_button(bake_in_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["None"] + list(sd_vae.vae_dict)}, "modelmerger_refresh_bake_in_vae")

                with FormRow():
                    discard_weights = gr.Textbox(value="", label="Discard weights with matching name", elem_id="modelmerger_discard_weights")

                with gr.Row():
                    modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')

@@ -1838,6 +1841,7 @@ def create_ui():
                checkpoint_format,
                config_source,
                bake_in_vae,
                discard_weights,
            ],
            outputs=[
                primary_model_name,