Commit 7cfc6450 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

eliminate repetition of code in #6910

parent 01b1061a
Loading
Loading
Loading
Loading
+8 −9
Original line number Diff line number Diff line
@@ -278,6 +278,11 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    shared.state.begin()
    shared.state.job = 'model-merge'

    def fail(message):
        shared.state.textinfo = message
        shared.state.end()
        return [message, *[gr.update() for _ in range(4)]]

    def weighted_sum(theta0, theta1, alpha):
        return ((1 - alpha) * theta0) + (alpha * theta1)

@@ -288,16 +293,12 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
        return theta0 + (alpha * theta1_2_diff)

    if not primary_model_name:
        shared.state.textinfo = "Failed: Merging requires a primary model."
        shared.state.end()
        return ["Failed: Merging requires a primary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
        return fail("Failed: Merging requires a primary model.")

    primary_model_info = sd_models.checkpoints_list[primary_model_name]

    if not secondary_model_name:
        shared.state.textinfo = "Failed: Merging requires a secondary model."
        shared.state.end()
        return ["Failed: Merging requires a secondary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
        return fail("Failed: Merging requires a secondary model.")
    
    secondary_model_info = sd_models.checkpoints_list[secondary_model_name]

@@ -308,9 +309,7 @@ def run_modelmerger(primary_model_name, secondary_model_name, tertiary_model_nam
    theta_func1, theta_func2 = theta_funcs[interp_method]

    if theta_func1 and not tertiary_model_name:
        shared.state.textinfo = "Failed: Interpolation method requires a tertiary model."
        shared.state.end()
        return [f"Failed: Interpolation method ({interp_method}) requires a tertiary model."] + [gr.Dropdown.update(choices=sd_models.checkpoint_tiles()) for _ in range(4)]
        return fail(f"Failed: Interpolation method ({interp_method}) requires a tertiary model.")
    
    tertiary_model_info = sd_models.checkpoints_list[tertiary_model_name] if theta_func1 else None