Commit c7e50425 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add progress bar to modelmerger

parent 7cfc6450
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -172,6 +172,17 @@ function submit_img2img(){
    return res
}

function modelmerger(){
    var id = randomId()
    requestProgress(id, gradioApp().getElementById('modelmerger_results_panel'), null, function(){})

    gradioApp().getElementById('modelmerger_result').innerHTML = ''

    var res = create_submit_args(arguments)
    res[0] = id
    return res
}


function ask_for_style_name(_, prompt_text, negative_prompt_text) {
    name_ = prompt('Style name:')
+15 −3
Original line number Diff line number Diff line
@@ -274,14 +274,15 @@ def create_config(ckpt_result, config_source, a, b, c):
    shutil.copyfile(cfg, checkpoint_filename)


def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source):
    shared.state.begin()
    shared.state.job = 'model-merge'
    shared.state.job_count = 1

    def fail(message):
        shared.state.textinfo = message
        shared.state.end()
        return [message, *[gr.update() for _ in range(4)]]
        return [*[gr.update() for _ in range(4)], message]

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)
@@ -320,9 +321,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    theta_1 = sd_models.read_state_dict(secondary_model_info.filename, map_location='cpu')

    if theta_func1:
        shared.state.job_count += 1

        print(f"Loading {tertiary_model_info.filename}...")
        theta_2 = sd_models.read_state_dict(tertiary_model_info.filename, map_location='cpu')

        shared.state.sampling_steps = len(theta_1.keys())
        for key in tqdm.tqdm(theta_1.keys()):
            if 'model' in key:
                if key in theta_2:
@@ -330,8 +334,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
                    theta_1[key] = theta_func1(theta_1[key], t2)
                else:
                    theta_1[key] = torch.zeros_like(theta_1[key])

            shared.state.sampling_step += 1
        del theta_2

        shared.state.nextjob()

    shared.state.textinfo = f"Loading {primary_model_info.filename}..."
    print(f"Loading {primary_model_info.filename}...")
    theta_0 = sd_models.read_state_dict(primary_model_info.filename, map_location='cpu')
@@ -340,6 +348,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam

    chckpoint_dict_skip_on_merge = ["cond_stage_model.transformer.text_model.embeddings.position_ids"]

    shared.state.sampling_steps = len(theta_0.keys())
    for key in tqdm.tqdm(theta_0.keys()):
        if 'model' in key and key in theta_1:

@@ -367,6 +376,8 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
            if save_as_half:
                theta_0[key] = theta_0[key].half()

        shared.state.sampling_step += 1

    # I believe this part should be discarded, but I'll leave it for now until I am sure
    for key in theta_1.keys():
        if 'model' in key and key not in theta_0:
@@ -393,6 +404,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam

    output_modelname = os.path.join(ckpt_dir, filename)

    shared.state.nextjob()
    shared.state.textinfo = f"Saving to {output_modelname}..."
    print(f"Saving to {output_modelname}...")

@@ -410,4 +422,4 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    shared.state.textinfo = "Checkpoint saved to " + output_modelname
    shared.state.end()

    return ["Checkpoint saved to " + output_modelname] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
    return [*[gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)], "Checkpoint saved to " + output_modelname]
+1 −1
Original line number Diff line number Diff line
@@ -72,7 +72,7 @@ def progressapi(req: ProgressRequest):

    if job_count > 0:
        progress += job_no / job_count
    if sampling_steps > 0:
    if sampling_steps > 0 and job_count > 0:
        progress += 1 / job_count * sampling_step / sampling_steps

    progress = min(progress, 1)
+8 −5
Original line number Diff line number Diff line
@@ -1208,8 +1208,9 @@ def create_ui():
                with gr.Row():
                    modelmerger_merge = gr.Button(elem_id="modelmerger_merge", value="Merge", variant='primary')

            with gr.Column(variant='panel'):
                submit_result = gr.Textbox(elem_id="modelmerger_result", show_label=False)
            with gr.Column(variant='compact', elem_id="modelmerger_results_container"):
                with gr.Group(elem_id="modelmerger_results_panel"):
                    modelmerger_result = gr.HTML(elem_id="modelmerger_result", show_label=False)

    with gr.Blocks(analytics_enabled=False) as train_interface:
        with gr.Row().style(equal_height=False):
@@ -1753,12 +1754,14 @@ def create_ui():
                print("Error loading/saving model file:", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
                modules.sd_models.list_models()  # to remove the potentially missing models from the list
                return [f"Error merging checkpoints: {e}"] + [gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)]
                return [*[gr.Dropdown.update(choices=modules.sd_models.checkpoint_tiles()) for _ in range(4)], f"Error merging checkpoints: {e}"]
            return results

        modelmerger_merge.click(
            fn=modelmerger,
            fn=wrap_gradio_gpu_call(modelmerger, extra_outputs=lambda: [gr.update() for _ in range(4)]),
            _js='modelmerger',
            inputs=[
                dummy_component,
                primary_model_name,
                secondary_model_name,
                tertiary_model_name,
@@ -1770,11 +1773,11 @@ def create_ui():
                config_source,
            ],
            outputs=[
                submit_result,
                primary_model_name,
                secondary_model_name,
                tertiary_model_name,
                component_dict['sd_model_checkpoint'],
                modelmerger_result,
            ]
        )

+5 −0
Original line number Diff line number Diff line
@@ -737,6 +737,11 @@ footer {
    line-height: 2.4em;
}

#modelmerger_results_container{
    margin-top: 1em;
    overflow: visible;
}

/* The following handles localization for right-to-left (RTL) languages like Arabic.
The rtl media type will only be activated by the logic in javascript/localization.js.
If you change anything above, you need to make sure it is RTL compliant by just running