Commit 45601766 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

added VAE selection to checkpoint user metadata

parent 31a9966b
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
import json
import os
import re
from collections import defaultdict

@@ -177,3 +179,20 @@ def parse_prompts(prompts):

    return res, extra_data


def get_user_metadata(filename):
    if filename is None:
        return {}

    basename, ext = os.path.splitext(filename)
    metadata_filename = basename + '.json'

    metadata = {}
    try:
        if os.path.isfile(metadata_filename):
            with open(metadata_filename, "r", encoding="utf8") as file:
                metadata = json.load(file)
    except Exception as e:
        errors.display(e, f"reading extra network user metadata from {metadata_filename}")

    return metadata
+12 −1
Original line number Diff line number Diff line
import os
import collections
from modules import paths, shared, devices, script_callbacks, sd_models
from modules import paths, shared, devices, script_callbacks, sd_models, extra_networks
import glob
from copy import deepcopy

@@ -16,6 +16,7 @@ checkpoint_info = None

checkpoints_loaded = collections.OrderedDict()


def get_base_vae(model):
    if base_vae is not None and checkpoint_info == model.sd_checkpoint_info and model:
        return base_vae
@@ -100,6 +101,16 @@ def resolve_vae(checkpoint_file):
    if shared.cmd_opts.vae_path is not None:
        return shared.cmd_opts.vae_path, 'from commandline argument'

    metadata = extra_networks.get_user_metadata(checkpoint_file)
    vae_metadata = metadata.get("vae", None)
    if vae_metadata is not None and vae_metadata != "Automatic":
        if vae_metadata == "None":
            return None, None

        vae_from_metadata = vae_dict.get(vae_metadata, None)
        if vae_from_metadata is not None:
            return vae_from_metadata, "from user metadata"

    is_automatic = shared.opts.sd_vae in {"Automatic", "auto"}  # "auto" for people with old config

    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
+2 −11
Original line number Diff line number Diff line
@@ -2,7 +2,7 @@ import os.path
import urllib.parse
from pathlib import Path

from modules import shared, ui_extra_networks_user_metadata, errors
from modules import shared, ui_extra_networks_user_metadata, errors, extra_networks
from modules.images import read_info_from_image, save_image_with_geninfo
from modules.ui import up_down_symbol
import gradio as gr
@@ -101,16 +101,7 @@ class ExtraNetworksPage:

    def read_user_metadata(self, item):
        filename = item.get("filename", None)
        basename, ext = os.path.splitext(filename)
        metadata_filename = basename + '.json'

        metadata = {}
        try:
            if os.path.isfile(metadata_filename):
                with open(metadata_filename, "r", encoding="utf8") as file:
                    metadata = json.load(file)
        except Exception as e:
            errors.display(e, f"reading extra network user metadata from {metadata_filename}")
        metadata = extra_networks.get_user_metadata(filename)

        desc = metadata.get("description", None)
        if desc is not None:
+3 −0
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ import os

from modules import shared, ui_extra_networks, sd_models
from modules.ui_extra_networks import quote_js
from modules.ui_extra_networks_checkpoints_user_metadata import CheckpointUserMetadataEditor


class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
@@ -34,3 +35,5 @@ class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
    def allowed_directories_for_previews(self):
        return [v for v in [shared.cmd_opts.ckpt_dir, sd_models.model_path] if v is not None]

    def create_user_metadata_editor(self, ui, tabname):
        return CheckpointUserMetadataEditor(ui, tabname, self)
+60 −0
Original line number Diff line number Diff line
import gradio as gr

from modules import ui_extra_networks_user_metadata, sd_vae
from modules.ui_common import create_refresh_button


class CheckpointUserMetadataEditor(ui_extra_networks_user_metadata.UserMetadataEditor):
    def __init__(self, ui, tabname, page):
        super().__init__(ui, tabname, page)

        self.select_vae = None

    def save_user_metadata(self, name, desc, notes, vae):
        user_metadata = self.get_user_metadata(name)
        user_metadata["description"] = desc
        user_metadata["notes"] = notes
        user_metadata["vae"] = vae

        self.write_user_metadata(name, user_metadata)

    def put_values_into_components(self, name):
        user_metadata = self.get_user_metadata(name)
        values = super().put_values_into_components(name)

        return [
            *values[0:5],
            user_metadata.get('vae', ''),
        ]

    def create_editor(self):
        self.create_default_editor_elems()

        with gr.Row():
            self.select_vae = gr.Dropdown(choices=["Automatic", "None"] + list(sd_vae.vae_dict), value="None", label="Preferred VAE", elem_id="checpoint_edit_user_metadata_preferred_vae")
            create_refresh_button(self.select_vae, sd_vae.refresh_vae_list, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, "checpoint_edit_user_metadata_refresh_preferred_vae")

        self.edit_notes = gr.TextArea(label='Notes', lines=4)

        self.create_default_buttons()

        viewed_components = [
            self.edit_name,
            self.edit_description,
            self.html_filedata,
            self.html_preview,
            self.edit_notes,
            self.select_vae,
        ]

        self.button_edit\
            .click(fn=self.put_values_into_components, inputs=[self.edit_name_input], outputs=viewed_components)\
            .then(fn=lambda: gr.update(visible=True), inputs=[], outputs=[self.box])

        edited_components = [
            self.edit_description,
            self.edit_notes,
            self.select_vae,
        ]

        self.setup_save_handler(self.button_save, self.save_user_metadata, edited_components)