Unverified Commit 1d9dc48e authored by Vladimir Mandic's avatar Vladimir Mandic Committed by GitHub
Browse files

init job and add info to model merge

parent e9fb9bb0
Loading
Loading
Loading
Loading
+12 −2
Original line number Diff line number Diff line
@@ -242,6 +242,9 @@ def run_pnginfo(image):


def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format):
    shared.state.begin()
    shared.state.job = 'model-merge'

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

@@ -263,8 +266,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    theta_func1, theta_func2 = theta_funcs[interp_method]

    if theta_func1 and not tertiary_model_info:
        shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
        shared.state.end()
        return ["Failed: Interpolation method requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]

    shared.state.textinfo = f"Loading {secondary_model_info.filename}..."
    print(f"Loading {secondary_model_info.filename}...")
    theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')

@@ -281,6 +287,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
                    theta_1[key] = torch.zeros_like(theta_1[key])
        del theta_2

    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
    print(f"Loading {primary_model_info.filename}...")
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')

@@ -291,6 +298,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
            a = theta_0[key]
            b = theta_1[key]

            shared.state.textinfo = f'Merging layer {key}'
            # this enables merging an inpainting model (A) with another one (B);
            # where normal model would have 4 channels, for latenst space, inpainting model would
            # have another 4 channels for unmasked picture's latent space, plus one channel for mask, for a total of 9
@@ -303,8 +311,6 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
                theta_0[key][:, 0:4, :, :] = theta_func2(a[:, 0:4, :, :], b, multiplier)
                result_is_inpainting_model = True
            else:
                assert a.shape == b.shape, f'Incompatible shapes for layer {key}: A is {a.shape}, and B is {b.shape}'

                theta_0[key] = theta_func2(a, b, multiplier)

            if save_as_half:
@@ -332,6 +338,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam

    output_modelname = os.path.join(ckpt_dir, filename)

    shared.state.textinfo = f"Saving to {output_modelname}..."
    print(f"Saving to {output_modelname}...")

    _, extension = os.path.splitext(output_modelname)
@@ -343,4 +350,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    sd_models.list_models()

    print("Checkpoint saved.")
    shared.state.textinfo = "Checkpoint saved to " + output_modelname
    shared.state.end()

    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]