Commit ada901ed authored by AUTOMATIC's avatar AUTOMATIC
Browse files

added console outputs, more clear indication of progress, and ability to...

added console outputs, more clear indication of progress, and ability to specify full filename to checkpoint merger
restore "Loading..." text
parent a9dc307a
Loading
Loading
Loading
Loading
+33 −15
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import numpy as np
from PIL import Image

import torch
import tqdm

from modules import processing, shared, images, devices
from modules.shared import opts
@@ -149,19 +150,35 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
        alpha = alpha * alpha * (3 - (2 * alpha))
        return theta0 + ((theta1 - theta0) * alpha)

    model_0 = torch.load('models/' + modelname_0 + '.ckpt')
    model_1 = torch.load('models/' + modelname_1 + '.ckpt')
    if os.path.exists(modelname_0):
        model0_filename = modelname_0
        modelname_0 = os.path.splitext(os.path.basename(modelname_0))[0]
    else:
        model0_filename = 'models/' + modelname_0 + '.ckpt'

    if os.path.exists(modelname_1):
        model1_filename = modelname_1
        modelname_1 = os.path.splitext(os.path.basename(modelname_1))[0]
    else:
        model1_filename = 'models/' + modelname_1 + '.ckpt'

    print(f"Loading {model0_filename}...")
    model_0 = torch.load(model0_filename, map_location='cpu')

    print(f"Loading {model1_filename}...")
    model_1 = torch.load(model1_filename, map_location='cpu')
    
    theta_0 = model_0['state_dict']
    theta_1 = model_1['state_dict']
    theta_func = weighted_sum

    if interp_method == "Weighted Sum":
        theta_func = weighted_sum
    if interp_method == "Sigmoid":
        theta_func = sigmoid
    theta_funcs = {
        "Weighted Sum": weighted_sum,
        "Sigmoid": sigmoid,
    }
    theta_func = theta_funcs[interp_method]

    for key in theta_0.keys():
    print(f"Merging...")
    for key in tqdm.tqdm(theta_0.keys()):
        if 'model' in key and key in theta_1:
            theta_0[key] = theta_func(theta_0[key], theta_1[key], interp_amount)
    
@@ -169,8 +186,9 @@ def run_modelmerger(modelname_0, modelname_1, interp_method, interp_amount):
        if 'model' in key and key not in theta_0:
            theta_0[key] = theta_1[key]

    output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt';
    
    output_modelname = 'models/' + modelname_0 + '-' + modelname_1 + '-merged.ckpt'
    print(f"Saving to {output_modelname}...")
    torch.save(model_0, output_modelname)

    return "<p>Model saved to " + output_modelname + "</p>"
    print(f"Checkpoint saved.")
    return "Checkpoint saved to " + output_modelname
+2 −1
Original line number Diff line number Diff line
@@ -49,6 +49,7 @@ sample_img2img = sample_img2img if os.path.exists(sample_img2img) else None

css_hide_progressbar = """
.wrap .m-12 svg { display:none!important; }
.wrap .m-12::before { content:"Loading..." }
.progress-bar { display:none!important; }
.meta-text { display:none!important; }
"""
@@ -865,7 +866,7 @@ def create_ui(txt2img, img2img, run_extras, run_pnginfo, run_modelmerger):
                submit = gr.Button(elem_id="modelmerger_merge", label="Merge", variant='primary')
            
            with gr.Column(variant='panel'):
                submit_result = gr.HTML(elem_id="modelmerger_result")
                submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)

            submit.click(
                fn=run_modelmerger,