Commit 0f5dbfff authored by AUTOMATIC's avatar AUTOMATIC
Browse files

allow baking in VAE in checkpoint merger tab

do not save config if it's the default for checkpoint merger tab
change file naming scheme for checkpoint merger tab
allow just saving A without any merging for checkpoint merger tab
some stylistic changes for UI in checkpoint merger tab
parent c7e50425
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -92,6 +92,7 @@ titles = {

    "Weighted sum": "Result = A * (1 - M) + B * M",
    "Add difference": "Result = A + (B - C) * M",
    "No interpolation": "Result = A",

	"Initialization text": "If the number of tokens is more than the number of vectors, some may be skipped.\nLeave the textbox empty to start with zeroed out vectors",
    "Learning rate": "How fast should training go. Low values will take longer to train, high values may fail to converge (not generate accurate results) and/or may break the embedding (This has happened if you see Loss: nan in the training info textbox. If this happens, you need to manually restore your embedding from an older not-broken backup).\n\nYou can set a single numeric value, or multiple learning rates using the syntax:\n\n   rate_1:max_steps_1, rate_2:max_steps_2, ...\n\nEG:   0.005:100, 1e-3:1000, 1e-5\n\nWill train with rate of 0.005 for first 100 steps, then 1e-3 until 1000 steps, then 1e-5 for all remaining steps.",
+0 −2
Original line number Diff line number Diff line
@@ -176,8 +176,6 @@ function modelmerger(){
    var id = randomId()
    requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})

    gradioApp().getElementById('modelmerger_result').innerHTML = ''

    var res = create_submit_args(arguments)
    res[0] = id
    return res
+68 −44
Original line number Diff line number Diff line
@@ -15,7 +15,7 @@ from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass

from modules import processing, shared, images, devices, sd_models, sd_samplers
from modules import processing, shared, images, devices, sd_models, sd_samplers, sd_vae
from modules.shared import opts
import modules.gfpgan_model
from modules.ui import plaintext_to_html
@@ -251,7 +251,8 @@ def run_pnginfo(image):

def create_config(ckpt_result, config_source, a, b, c):
    def config(x):
        return sd_models.find_checkpoint_config(x) if x else None
        res = sd_models.find_checkpoint_config(x) if x else None
        return res if res != shared.sd_default_config else None

    if config_source == 0:
        cfg = config(a) or config(b) or config(c)
@@ -274,10 +275,12 @@ def create_config(ckpt_result, config_source, a, b, c):
    shutil.copyfile(cfg, checkpoint_filename)


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae):
    shared.state.begin()
    shared.state.job = 'model-merge'
    shared.state.job_count = 1

    def fail(message):
        shared.state.textinfo = message
@@ -293,21 +296,42 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    def add_difference(theta0, theta1_2_diff, alpha):
        """Return ``theta0`` plus ``theta1_2_diff`` scaled by ``alpha``."""
        scaled_diff = alpha * theta1_2_diff
        return theta0 + scaled_diff

    def filename_weighed_sum():
        """Build the default output name for a weighted-sum merge.

        Reads ``primary_model_info``, ``secondary_model_info`` and
        ``multiplier`` from the enclosing scope. The weights for A and B are
        ``1 - multiplier`` and ``multiplier``, each rounded to two decimals.
        """
        name_a = primary_model_info.model_name
        name_b = secondary_model_info.model_name
        weight_a = round(1 - multiplier, 2)
        weight_b = round(multiplier, 2)

        return f"{weight_a}({name_a}) + {weight_b}({name_b})"

    def filename_add_differnece():
        """Build the default output name for an add-difference merge.

        Reads ``primary_model_info``, ``secondary_model_info``,
        ``tertiary_model_info`` and ``multiplier`` as free variables; the
        multiplier is rounded to two decimals before being embedded.
        """
        name_a = primary_model_info.model_name
        name_b = secondary_model_info.model_name
        name_c = tertiary_model_info.model_name
        scale = round(multiplier, 2)

        return f"{name_a} + {scale}({name_b} - {name_c})"

    def filename_nothing():
        """Return the primary model's name, used when A is saved without merging."""
        unmerged_name = primary_model_info.model_name
        return unmerged_name

    theta_funcs = {
        "Weighted sum": (filename_weighed_sum, None, weighted_sum),
        "Add difference": (filename_add_differnece, get_difference, add_difference),
        "No interpolation": (filename_nothing, None, None),
    }
    filename_generator, theta_func1, theta_func2 = theta_funcs[interp_method]
    shared.state.job_count = (1 if theta_func1 else 0) + (1 if theta_func2 else 0)

    if not primary_model_name:
        return fail("Failed: Merging requires a primary model.")

    primary_model_info = sd_models.checkpoints_list[primary_model_name]

    if not secondary_model_name:
    if theta_func2 and not secondary_model_name:
        return fail("Failed: Merging requires a secondary model.")

    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]

    theta_funcs = {
        "Weighted sum": (None, weighted_sum),
        "Add difference": (get_difference, add_difference),
    }
    theta_func1, theta_func2 = theta_funcs[interp_method]
    secondary_model_info = sd_models.checkpoints_list[secondary_model_name] if theta_func2 else None

    if theta_func1 and not tertiary_model_name:
        return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
@@ -316,18 +340,24 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_

    result_is_inpainting_model = False

    shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
    if theta_func2:
        shared.state.textinfo = f"Loading B"
        print(f"Loading {secondary_model_info.filename}...")
        theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')
    else:
        theta_1 = None

    if theta_func1:
        shared.state.job_count += 1

        shared.state.textinfo = f"Loading C"
        print(f"Loading {tertiary_model_info.filename}...")
        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

        shared.state.textinfo = 'Merging B and C'
        shared.state.sampling_steps = len(theta_1.keys())
        for key in tqdm.tqdm(theta_1.keys()):
            if key in chckpoint_dict_skip_on_merge:
                continue

            if 'model' in key:
                if key in theta_2:
                    t2 = theta_2.get(key, torch.zeros_like(theta_1[key]))
@@ -345,12 +375,10 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')

    print("Merging...")

    chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]

    shared.state.textinfo = 'Merging A and B'
    shared.state.sampling_steps = len(theta_0.keys())
    for key in tqdm.tqdm(theta_0.keys()):
        if 'model' in key and key in theta_1:
        if theta_1 and 'model' in key and key in theta_1:

            if key in chckpoint_dict_skip_on_merge:
                continue
@@ -358,7 +386,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
            a = theta_0[key]
            b = theta_1[key]

            shared.state.textinfo = f'Merging layer {key}'
            # this enables merging an inpainting model (A) with another one (B);
            # where normal model would have 4 channels, for latent space, inpainting model would
            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -378,34 +405,31 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_

        shared.state.sampling_step += 1

    # I believe this part should be discarded, but I'll leave it for now until I am sure
    for key in theta_1.keys():
        if 'model' in key and key not in theta_0:
    del theta_1

            if key in chckpoint_dict_skip_on_merge:
                continue
    bake_in_vae_filename = sd_vae.vae_dict.get(bake_in_vae, None)
    if bake_in_vae_filename is not None:
        print(f"Baking in VAE from {bake_in_vae_filename}")
        shared.state.textinfo = 'Baking in VAE'
        vae_dict = sd_vae.load_vae_dict(bake_in_vae_filename, map_location='cpu')

            theta_0[key] = theta_1[key]
            if save_as_half:
                theta_0[key] = theta_0[key].half()
    del theta_1
        for key in vae_dict.keys():
            theta_0_key = 'first_stage_model.' + key
            if theta_0_key in theta_0:
                theta_0[theta_0_key] = vae_dict[key].half() if save_as_half else vae_dict[key]

    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path
        del vae_dict

    filename = \
        primary_model_info.model_name + '_' + str(round(1-multiplier, 2)) + '-' + \
        secondary_model_info.model_name + '_' + str(round(multiplier, 2)) + '-' + \
        interp_method.replace(" ", "_") + \
        '-merged.' +  \
        ("inpainting." if result_is_inpainting_model else "") + \
        checkpoint_format
    ckpt_dir = shared.cmd_opts.ckpt_dir or sd_models.model_path

    filename = filename if custom_name == '' else (custom_name + '.' + checkpoint_format)
    filename = filename_generator() if custom_name == '' else custom_name
    filename += ".inpainting" if result_is_inpainting_model else ""
    filename += "." + checkpoint_format

    output_modelname = os.path.join(ckpt_dir, filename)

    shared.state.nextjob()
    shared.state.textinfo = f"Saving to {output_modelname}..."
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    _, extension = os.path.splitext(output_modelname)
@@ -418,8 +442,8 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_

    create_config(output_modelname, config_source, primary_model_info, secondary_model_info, tertiary_model_info)

    print("Checkpoint saved.")
    shared.state.textinfo = "Checkpoint saved to " + output_modelname
    print(f"Checkpoint saved to {output_modelname}.")
    shared.state.textinfo = "Checkpoint saved"
    shared.state.end()

    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
+7 −2
Original line number Diff line number Diff line
@@ -120,6 +120,12 @@ def resolve_vae(checkpoint_file):
    return None, None


def load_vae_dict(filename, map_location):
    """Read a VAE checkpoint and return the entries that should be loaded.

    The state dict is read via ``sd_models.read_state_dict``. Entries whose
    key starts with "loss" or appears in ``vae_ignore_keys`` are left out.
    """
    state_dict = sd_models.read_state_dict(filename, map_location=map_location)

    kept = {}
    for name, value in state_dict.items():
        if name[0:4] == "loss" or name in vae_ignore_keys:
            continue
        kept[name] = value

    return kept


def load_vae(model, vae_file=None, vae_source="from unknown source"):
    global vae_dict, loaded_vae_file
    # save_settings = False
@@ -137,8 +143,7 @@ def load_vae(model, vae_file=None, vae_source="from unknown source"):
            print(f"Loading VAE weights {vae_source}: {vae_file}")
            store_base_vae(model)

            vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
            vae_dict_1 = {k: v for k, v in vae_ckpt.items() if k[0:4] != "loss" and k not in vae_ignore_keys}
            vae_dict_1 = load_vae_dict(vae_file, map_location=shared.weight_load_location)
            _load_vae_dict(model, vae_dict_1)

            if cache_enabled:
+2 −1
Original line number Diff line number Diff line
@@ -20,10 +20,11 @@ from modules.paths import models_path, script_path, sd_path

demo = None

sd_default_config = os.path.join(script_path, "configs/v1-inference.yaml")
sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, default=os.path.join(script_path, "configs/v1-inference.yaml"), help="path to config which constructs model",)
parser.add_argument("--config", type=str, default=sd_default_config, help="path to config which constructs model",)
parser.add_argument("--ckpt", type=str, default=sd_model_file, help="path to checkpoint of stable diffusion model; if specified, this checkpoint will be added to the list of checkpoints and loaded",)
parser.add_argument("--ckpt-dir", type=str, default=None, help="Path to directory with stable diffusion checkpoints")
parser.add_argument("--vae-dir", type=str, default=None, help="Path to directory with VAE files")
Loading