Commit 0c9b1e79 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

Merge branch 'dev' into multiple_loaded_models

parents 151b8ed3 6a0d498c
Loading
Loading
Loading
Loading
+11 −0
Original line number Diff line number Diff line
@@ -190,3 +190,14 @@ onUiUpdate(function(mutationRecords) {
        tooltipCheckTimer = setTimeout(processTooltipCheckNodes, 1000);
    }
});

// After the UI finishes loading, propagate any `webui_tooltip` declared in
// the gradio config onto the matching DOM element's native title attribute.
onUiLoaded(function() {
    window.gradio_config.components.forEach(function(component) {
        var tooltip = component.props.webui_tooltip;
        var elemId = component.props.elem_id;
        if (!tooltip || !elemId) {
            return;
        }
        var target = gradioApp().getElementById(elemId);
        if (target) {
            target.title = tooltip;
        }
    });
});
+10 −4
Original line number Diff line number Diff line
@@ -152,7 +152,11 @@ function submit() {
    showSubmitButtons('txt2img', false);

    var id = randomId();
    try {
        localStorage.setItem("txt2img_task_id", id);
    } catch (e) {
        console.warn(`Failed to save txt2img task id to localStorage: ${e}`);
    }

    requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
        showSubmitButtons('txt2img', true);
@@ -171,7 +175,11 @@ function submit_img2img() {
    showSubmitButtons('img2img', false);

    var id = randomId();
    try {
        localStorage.setItem("img2img_task_id", id);
    } catch (e) {
        console.warn(`Failed to save img2img task id to localStorage: ${e}`);
    }

    requestProgress(id, gradioApp().getElementById('img2img_gallery_container'), gradioApp().getElementById('img2img_gallery'), function() {
        showSubmitButtons('img2img', true);
@@ -191,8 +199,6 @@ function restoreProgressTxt2img() {
    showRestoreProgressButton("txt2img", false);
    var id = localStorage.getItem("txt2img_task_id");

    id = localStorage.getItem("txt2img_task_id");

    if (id) {
        requestProgress(id, gradioApp().getElementById('txt2img_gallery_container'), gradioApp().getElementById('txt2img_gallery'), function() {
            showSubmitButtons('txt2img', true);
+3 −1
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ import html
import threading
import time

from modules import shared, progress, errors
from modules import shared, progress, errors, devices

queue_lock = threading.Lock()

@@ -75,6 +75,8 @@ def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
            error_message = f'{type(e).__name__}: {e}'
            res = extra_outputs_array + [f"<div class='error'>{html.escape(error_message)}</div>"]

        devices.torch_gc()

        shared.state.skipped = False
        shared.state.interrupted = False
        shared.state.job_count = 0
+2 −1
Original line number Diff line number Diff line
@@ -14,7 +14,8 @@ def record_exception():
    if exception_records and exception_records[-1] == e:
        return

    exception_records.append((e, tb))
    from modules import sysinfo
    exception_records.append(sysinfo.format_exception(e, tb))

    if len(exception_records) > 5:
        exception_records.pop(0)
+33 −6
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ import json
import torch
import tqdm

from modules import shared, images, sd_models, sd_vae, sd_models_config
from modules import shared, images, sd_models, sd_vae, sd_models_config, errors
from modules.ui_common import plaintext_to_html
import gradio as gr
import safetensors.torch
@@ -72,7 +72,20 @@ def to_half(tensor, enable):
    return tensor


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata):
def read_metadata(primary_model_name, secondary_model_name, tertiary_model_name):
    """Merge the stored metadata of up to three named checkpoints.

    Checkpoint names that are missing from ``sd_models.checkpoints_list``
    are silently skipped; for duplicate keys, later checkpoints in the
    primary/secondary/tertiary order win.  Returns the merged metadata as
    a pretty-printed (indent=4) JSON string with non-ASCII preserved.
    """
    merged = {}

    for name in (primary_model_name, secondary_model_name, tertiary_model_name):
        info = sd_models.checkpoints_list.get(name, None)
        if info is not None:
            merged.update(info.metadata)

    return json.dumps(merged, indent=4, ensure_ascii=False)


def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_model_name, interp_method, multiplier, save_as_half, custom_name, checkpoint_format, config_source, bake_in_vae, discard_weights, save_metadata, add_merge_recipe, copy_metadata_fields, metadata_json):
    shared.state.begin(job="model-merge")

    def fail(message):
@@ -241,11 +254,25 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
    shared.state.textinfo = "Saving"
    print(f"Saving to {output_modelname}...")

    metadata = None
    metadata = {}

    if save_metadata and copy_metadata_fields:
        if primary_model_info:
            metadata.update(primary_model_info.metadata)
        if secondary_model_info:
            metadata.update(secondary_model_info.metadata)
        if tertiary_model_info:
            metadata.update(tertiary_model_info.metadata)

    if save_metadata:
        metadata = {"format": "pt"}
        try:
            metadata.update(json.loads(metadata_json))
        except Exception as e:
            errors.display(e, "readin metadata from json")

        metadata["format"] = "pt"

    if save_metadata and add_merge_recipe:
        merge_recipe = {
            "type": "webui", # indicate this model was merged with webui's built-in merger
            "primary_model_hash": primary_model_info.sha256,
@@ -261,7 +288,6 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
            "is_inpainting": result_is_inpainting_model,
            "is_instruct_pix2pix": result_is_instruct_pix2pix_model
        }
        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)

        sd_merge_models = {}

@@ -281,11 +307,12 @@ def run_modelmerger(id_task, primary_model_name, secondary_model_name, tertiary_
        if tertiary_model_info:
            add_model_metadata(tertiary_model_info)

        metadata["sd_merge_recipe"] = json.dumps(merge_recipe)
        metadata["sd_merge_models"] = json.dumps(sd_merge_models)

    _, extension = os.path.splitext(output_modelname)
    if extension.lower() == ".safetensors":
        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata)
        safetensors.torch.save_file(theta_0, output_modelname, metadata=metadata if len(metadata)>0 else None)
    else:
        torch.save(theta_0, output_modelname)

Loading