Commit 8aa87c56 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add UI to edit defaults

allow setting defaults for elements in extensions' tabs
fix a problem with ESRGAN upscalers disappearing after UI reload
implicit change: HTML element id for train tab from tab_ti to tab_train (will this break things?)
parent 5abecea3
Loading
Loading
Loading
Loading
+10 −17
Original line number Diff line number Diff line
@@ -116,20 +116,6 @@ def move_files(src_path: str, dest_path: str, ext_filter: str = None):
        pass


builtin_upscaler_classes = []
forbidden_upscaler_classes = set()


def list_builtin_upscalers():
    builtin_upscaler_classes.clear()
    builtin_upscaler_classes.extend(Upscaler.__subclasses__())

def forbid_loaded_nonbuiltin_upscalers():
    for cls in Upscaler.__subclasses__():
        if cls not in builtin_upscaler_classes:
            forbidden_upscaler_classes.add(cls)


def load_upscalers():
    # We can only do this 'magic' method to dynamically load upscalers if they are referenced,
    # so we'll try to import any _model.py files before looking in __subclasses__
@@ -145,10 +131,17 @@ def load_upscalers():

    datas = []
    commandline_options = vars(shared.cmd_opts)
    for cls in Upscaler.__subclasses__():
        if cls in forbidden_upscaler_classes:
            continue

    # some of upscaler classes will not go away after reloading their modules, and we'll end
    # up with two copies of those classes. The newest copy will always be the last in the list,
    # so we go from end to beginning and ignore duplicates
    used_classes = {}
    for cls in reversed(Upscaler.__subclasses__()):
        classname = str(cls)
        if classname not in used_classes:
            used_classes[classname] = cls

    for cls in reversed(used_classes.values()):
        name = cls.__name__
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        scaler = cls(commandline_options.get(cmd_name, None))
+19 −103
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ import numpy as np
from PIL import Image, PngImagePlugin  # noqa: F401
from modules.call_queue import wrap_gradio_gpu_call, wrap_queued_call, wrap_gradio_call

from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress
from modules import sd_hijack, sd_models, localization, script_callbacks, ui_extensions, deepbooru, sd_vae, extra_networks, ui_common, ui_postprocessing, progress, ui_loadsave
from modules.ui_components import FormRow, FormGroup, ToolButton, FormHTML
from modules.paths import script_path, data_path

@@ -86,16 +86,6 @@ def send_gradio_gallery_to_image(x):
        return None
    return image_from_url_text(x[0])

def visit(x, func, path=""):
    if hasattr(x, 'children'):
        if isinstance(x, gr.Tabs) and x.elem_id is not None:
            # Tabs element can't have a label, have to use elem_id instead
            func(f"{path}/Tabs@{x.elem_id}", x)
        for c in x.children:
            visit(c, func, path)
    elif x.label is not None:
        func(f"{path}/{x.label}", x)


def add_style(name: str, prompt: str, negative_prompt: str):
    if name is None:
@@ -1471,6 +1461,8 @@ def create_ui():

        return res

    loadsave = ui_loadsave.UiLoadsave(cmd_opts.ui_config_file)

    components = []
    component_dict = {}
    shared.settings_components = component_dict
@@ -1558,6 +1550,9 @@ def create_ui():
                current_row.__exit__()
                current_tab.__exit__()

            with gr.TabItem("Defaults", id="defaults", elem_id="settings_tab_defaults"):
                loadsave.create_ui()

            with gr.TabItem("Actions", id="actions", elem_id="settings_tab_actions"):
                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
@@ -1631,7 +1626,7 @@ def create_ui():
        (extras_interface, "Extras", "extras"),
        (pnginfo_interface, "PNG Info", "pnginfo"),
        (modelmerger_interface, "Checkpoint Merger", "modelmerger"),
        (train_interface, "Train", "ti"),
        (train_interface, "Train", "train"),
    ]

    interfaces += script_callbacks.ui_tabs_callback()
@@ -1659,6 +1654,16 @@ def create_ui():
                with gr.TabItem(label, id=ifid, elem_id=f"tab_{ifid}"):
                    interface.render()

            for interface, _label, ifid in interfaces:
                if ifid in ["extensions", "settings"]:
                    continue

                loadsave.add_block(interface, ifid)

            loadsave.add_component(f"webui/Tabs@{tabs.elem_id}", tabs)

            loadsave.setup_ui()

        if os.path.exists(os.path.join(script_path, "notification.mp3")):
            gr.Audio(interactive=False, value=os.path.join(script_path, "notification.mp3"), elem_id="audio_notification", visible=False)

@@ -1747,97 +1752,8 @@ def create_ui():
            ]
        )

    ui_config_file = cmd_opts.ui_config_file
    ui_settings = {}
    settings_count = len(ui_settings)
    error_loading = False

    try:
        if os.path.exists(ui_config_file):
            with open(ui_config_file, "r", encoding="utf8") as file:
                ui_settings = json.load(file)
    except Exception:
        error_loading = True
        print("Error loading settings:", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)

    def loadsave(path, x):
        def apply_field(obj, field, condition=None, init_field=None):
            key = f"{path}/{field}"

            if getattr(obj, 'custom_script_source', None) is not None:
              key = f"customscript/{obj.custom_script_source}/{key}"

            if getattr(obj, 'do_not_save_to_config', False):
                return

            saved_value = ui_settings.get(key, None)
            if saved_value is None:
                ui_settings[key] = getattr(obj, field)
            elif condition and not condition(saved_value):
                pass

                # this warning is generally not useful;
                # print(f'Warning: Bad ui setting value: {key}: {saved_value}; Default value "{getattr(obj, field)}" will be used instead.')
            else:
                setattr(obj, field, saved_value)
                if init_field is not None:
                    init_field(saved_value)

        if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
            apply_field(x, 'visible')

        if type(x) == gr.Slider:
            apply_field(x, 'value')
            apply_field(x, 'minimum')
            apply_field(x, 'maximum')
            apply_field(x, 'step')

        if type(x) == gr.Radio:
            apply_field(x, 'value', lambda val: val in x.choices)

        if type(x) == gr.Checkbox:
            apply_field(x, 'value')

        if type(x) == gr.Textbox:
            apply_field(x, 'value')

        if type(x) == gr.Number:
            apply_field(x, 'value')

        if type(x) == gr.Dropdown:
            def check_dropdown(val):
                if getattr(x, 'multiselect', False):
                    return all(value in x.choices for value in val)
                else:
                    return val in x.choices

            apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))

        def check_tab_id(tab_id):
            tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
            if type(tab_id) == str:
                tab_ids = [t.id for t in tab_items]
                return tab_id in tab_ids
            elif type(tab_id) == int:
                return tab_id >= 0 and tab_id < len(tab_items)
            else:
                return False

        if type(x) == gr.Tabs:
            apply_field(x, 'selected', check_tab_id)

    visit(txt2img_interface, loadsave, "txt2img")
    visit(img2img_interface, loadsave, "img2img")
    visit(extras_interface, loadsave, "extras")
    visit(modelmerger_interface, loadsave, "modelmerger")
    visit(train_interface, loadsave, "train")

    loadsave(f"webui/Tabs@{tabs.elem_id}", tabs)

    if not error_loading and (not os.path.exists(ui_config_file) or settings_count != len(ui_settings)):
        with open(ui_config_file, "w", encoding="utf8") as file:
            json.dump(ui_settings, file, indent=4)
    loadsave.dump_defaults()
    demo.ui_loadsave = loadsave

    # Required as a workaround for change() event not triggering when loading values from ui-config.json
    interp_description.value = update_interp_description(interp_method.value)

modules/ui_loadsave.py

0 → 100644
+208 −0
Original line number Diff line number Diff line
import json
import os

import gradio as gr

from modules import errors
from modules.ui_components import ToolButton


class UiLoadsave:
    """allows saving and restorig default values for gradio components"""

    def __init__(self, filename):
        self.filename = filename
        self.ui_settings = {}
        self.component_mapping = {}
        self.error_loading = False
        self.finalized_ui = False

        self.ui_defaults_view = None
        self.ui_defaults_apply = None
        self.ui_defaults_review = None

        try:
            if os.path.exists(self.filename):
                self.ui_settings = self.read_from_file()
        except Exception as e:
            self.error_loading = True
            errors.display(e, "loading settings")

    def add_component(self, path, x):
        """adds component to the registry of tracked components"""

        assert not self.finalized_ui

        def apply_field(obj, field, condition=None, init_field=None):
            key = f"{path}/{field}"

            if getattr(obj, 'custom_script_source', None) is not None:
              key = f"customscript/{obj.custom_script_source}/{key}"

            if getattr(obj, 'do_not_save_to_config', False):
                return

            saved_value = self.ui_settings.get(key, None)
            if saved_value is None:
                self.ui_settings[key] = getattr(obj, field)
            elif condition and not condition(saved_value):
                pass
            else:
                setattr(obj, field, saved_value)
                if init_field is not None:
                    init_field(saved_value)

            if field == 'value' and key not in self.component_mapping:
                self.component_mapping[key] = x

        if type(x) in [gr.Slider, gr.Radio, gr.Checkbox, gr.Textbox, gr.Number, gr.Dropdown, ToolButton] and x.visible:
            apply_field(x, 'visible')

        if type(x) == gr.Slider:
            apply_field(x, 'value')
            apply_field(x, 'minimum')
            apply_field(x, 'maximum')
            apply_field(x, 'step')

        if type(x) == gr.Radio:
            apply_field(x, 'value', lambda val: val in x.choices)

        if type(x) == gr.Checkbox:
            apply_field(x, 'value')

        if type(x) == gr.Textbox:
            apply_field(x, 'value')

        if type(x) == gr.Number:
            apply_field(x, 'value')

        if type(x) == gr.Dropdown:
            def check_dropdown(val):
                if getattr(x, 'multiselect', False):
                    return all(value in x.choices for value in val)
                else:
                    return val in x.choices

            apply_field(x, 'value', check_dropdown, getattr(x, 'init_field', None))

        def check_tab_id(tab_id):
            tab_items = list(filter(lambda e: isinstance(e, gr.TabItem), x.children))
            if type(tab_id) == str:
                tab_ids = [t.id for t in tab_items]
                return tab_id in tab_ids
            elif type(tab_id) == int:
                return 0 <= tab_id < len(tab_items)
            else:
                return False

        if type(x) == gr.Tabs:
            apply_field(x, 'selected', check_tab_id)

    def add_block(self, x, path=""):
        """adds all components inside a gradio block x to the registry of tracked components"""

        if hasattr(x, 'children'):
            if isinstance(x, gr.Tabs) and x.elem_id is not None:
                # Tabs element can't have a label, have to use elem_id instead
                self.add_component(f"{path}/Tabs@{x.elem_id}", x)
            for c in x.children:
                self.add_block(c, path)
        elif x.label is not None:
            self.add_component(f"{path}/{x.label}", x)

    def read_from_file(self):
        with open(self.filename, "r", encoding="utf8") as file:
            return json.load(file)

    def write_to_file(self, current_ui_settings):
        with open(self.filename, "w", encoding="utf8") as file:
            json.dump(current_ui_settings, file, indent=4)

    def dump_defaults(self):
        """saves default values to a file unless tjhe file is present and there was an error loading default values at start"""

        if self.error_loading and os.path.exists(self.filename):
            return

        self.write_to_file(self.ui_settings)

    def iter_changes(self, current_ui_settings, values):
        """
        given a dictionary with defaults from a file and current values from gradio elements, returns
        an iterator over tuples of values that are not the same between the file and the current;
        tuple contents are: path, old value, new value
        """

        for (path, component), new_value in zip(self.component_mapping.items(), values):
            old_value = current_ui_settings.get(path)

            choices = getattr(component, 'choices', None)
            if isinstance(new_value, int) and choices:
                if new_value >= len(choices):
                    continue

                new_value = choices[new_value]

            if new_value == old_value:
                continue

            if old_value is None and new_value == '' or new_value == []:
                continue

            yield path, old_value, new_value

    def ui_view(self, *values):
        text = ["<table><thead><tr><th>Path</th><th>Old value</th><th>New value</th></thead><tbody>"]

        for path, old_value, new_value in self.iter_changes(self.read_from_file(), values):
            if old_value is None:
                old_value = "<span class='ui-defaults-none'>None</span>"

            text.append(f"<tr><td>{path}</td><td>{old_value}</td><td>{new_value}</td></tr>")

        if len(text) == 1:
            text.append("<tr><td colspan=3>No changes</td></tr>")

        text.append("</tbody>")
        return "".join(text)

    def ui_apply(self, *values):
        num_changed = 0

        current_ui_settings = self.read_from_file()

        for path, _, new_value in self.iter_changes(current_ui_settings.copy(), values):
            num_changed += 1
            current_ui_settings[path] = new_value

        if num_changed == 0:
            return "No changes."

        self.write_to_file(current_ui_settings)

        return f"Wrote {num_changed} changes."

    def create_ui(self):
        """creates ui elements for editing defaults UI, without adding any logic to them"""

        gr.HTML(
            f"This page allows you to change default values in UI elements on other tabs.<br />"
            f"Make your changes, press 'View changes' to review the changed default values,<br />"
            f"then press 'Apply' to write them to {self.filename}.<br />"
            f"New defaults will apply after you restart the UI.<br />"
        )

        with gr.Row():
            self.ui_defaults_view = gr.Button(value='View changes', elem_id="ui_defaults_view", variant="secondary")
            self.ui_defaults_apply = gr.Button(value='Apply', elem_id="ui_defaults_apply", variant="primary")

        self.ui_defaults_review = gr.HTML("")

    def setup_ui(self):
        """adds logic to elements created with create_ui; all add_block class must be made before this"""

        assert not self.finalized_ui
        self.finalized_ui = True

        self.ui_defaults_view.click(fn=self.ui_view, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
        self.ui_defaults_apply.click(fn=self.ui_apply, inputs=list(self.component_mapping.values()), outputs=[self.ui_defaults_review])
+4 −0
Original line number Diff line number Diff line
@@ -414,6 +414,10 @@ table.settings-value-table td{
    max-width: 36em;
}

.ui-defaults-none{
    color: #aaa !important;
}

/* live preview */
.progressDiv{
    position: relative;
+1 −5
Original line number Diff line number Diff line
@@ -181,14 +181,11 @@ def initialize():
    gfpgan.setup_model(cmd_opts.gfpgan_models_path)
    startup_timer.record("setup gfpgan")

    modelloader.list_builtin_upscalers()
    startup_timer.record("list builtin upscalers")

    modules.scripts.load_scripts()
    startup_timer.record("load scripts")

    modelloader.load_upscalers()
    #startup_timer.record("load upscalers") #Is this necessary? I don't know.
    startup_timer.record("load upscalers")

    modules.sd_vae.refresh_vae_list()
    startup_timer.record("refresh VAE")
@@ -388,7 +385,6 @@ def webui():

        localization.list_localizations(cmd_opts.localizations_dir)

        modelloader.forbid_loaded_nonbuiltin_upscalers()
        modules.scripts.reload_scripts()
        startup_timer.record("load scripts")