Commit 1d8e06d5 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add checkpoints tab for extra networks UI

parent 91c8d0dc
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ class ExtraNetworksPageLora(ui_extra_networks.ExtraNetworksPage):
            preview = None
            for file in previews:
                if os.path.isfile(file):
                    preview = "./file=" + file.replace('\\', '/') + "?mtime=" + str(os.path.getmtime(file))
                    preview = self.link_preview(file)
                    break

            yield {
+7 −0
Original line number Diff line number Diff line
@@ -309,3 +309,10 @@ function updateInput(target){
	Object.defineProperty(e, "target", {value: target})
	target.dispatchEvent(e);
}


var desiredCheckpointName = null;
function selectCheckpoint(name){
    desiredCheckpointName = name;
    gradioApp().getElementById('change_checkpoint').click()
}
+8 −0
Original line number Diff line number Diff line
@@ -1560,6 +1560,14 @@ def create_ui():
                outputs=[component, text_settings],
            )

        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
        button_set_checkpoint.click(
            fn=lambda value, _: run_settings_single(value, key='sd_model_checkpoint'),
            _js="function(v){ var res = desiredCheckpointName; desiredCheckpointName = ''; return [res || v, null]; }",
            inputs=[component_dict['sd_model_checkpoint'], dummy_component],
            outputs=[component_dict['sd_model_checkpoint'], text_settings],
        )

        component_keys = [k for k in opts.data_labels.keys() if k in component_dict]

        def get_settings_values():
+33 −4
Original line number Diff line number Diff line
import os.path
import urllib.parse
from pathlib import Path

from modules import shared
import gradio as gr
@@ -8,12 +10,31 @@ import html
from modules.generation_parameters_copypaste import image_from_url_text

extra_pages = []
allowed_dirs = set()


def register_page(page):
    """registers extra networks page for the UI; recommend doing it in on_before_ui() callback for extensions"""

    extra_pages.append(page)
    allowed_dirs.clear()
    allowed_dirs.update(set(sum([x.allowed_directories_for_previews() for x in extra_pages], [])))


def add_pages_to_demo(app):
    def fetch_file(filename: str = ""):
        from starlette.responses import FileResponse

        if not any([Path(x).resolve() in Path(filename).resolve().parents for x in allowed_dirs]):
            raise ValueError(f"File cannot be fetched: {filename}. Must be in one of directories registered by extra pages.")

        if os.path.splitext(filename)[1].lower() != ".png":
            raise ValueError(f"File cannot be fetched: {filename}. Only png.")

        # would profit from returning 304
        return FileResponse(filename, headers={"Accept-Ranges": "bytes"})

    app.add_api_route("/sd_extra_networks/thumb", fetch_file, methods=["GET"])


class ExtraNetworksPage:
@@ -26,6 +47,9 @@ class ExtraNetworksPage:
    def refresh(self):
        pass

    def link_preview(self, filename):
        return "./sd_extra_networks/thumb?filename=" + urllib.parse.quote(filename.replace('\\', '/')) + "&mtime=" + str(os.path.getmtime(filename))

    def create_html(self, tabname):
        view = shared.opts.extra_networks_default_view
        items_html = ''
@@ -54,13 +78,17 @@ class ExtraNetworksPage:
    def create_html_for_item(self, item, tabname):
        preview = item.get("preview", None)

        onclick = item.get("onclick", None)
        if onclick is None:
            onclick = '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"'

        args = {
            "preview_html": "style='background-image: url(\"" + html.escape(preview) + "\")'" if preview else '',
            "prompt": item["prompt"],
            "prompt": item.get("prompt", None),
            "tabname": json.dumps(tabname),
            "local_preview": json.dumps(item["local_preview"]),
            "name": item["name"],
            "card_clicked": '"' + html.escape(f"""return cardClicked({json.dumps(tabname)}, {item["prompt"]}, {"true" if self.allow_negative_prompt else "false"})""") + '"',
            "card_clicked": onclick,
            "save_card_preview": '"' + html.escape(f"""return saveCardPreview(event, {json.dumps(tabname)}, {json.dumps(item["local_preview"])})""") + '"',
        }

@@ -143,7 +171,7 @@ def path_is_parent(parent_path, child_path):
    parent_path = os.path.abspath(parent_path)
    child_path = os.path.abspath(child_path)

    return os.path.commonpath([parent_path]) == os.path.commonpath([parent_path, child_path])
    return child_path.startswith(parent_path)


def setup_ui(ui, gallery):
@@ -173,7 +201,8 @@ def setup_ui(ui, gallery):

    ui.button_save_preview.click(
        fn=save_preview,
        _js="function(x, y, z){console.log(x, y, z); return [selected_gallery_index(), y, z]}",
        _js="function(x, y, z){return [selected_gallery_index(), y, z]}",
        inputs=[ui.preview_target_filename, gallery, ui.preview_target_filename],
        outputs=[*ui.pages]
    )
+38 −0
Original line number Diff line number Diff line
import html
import json
import os
import urllib.parse

from modules import shared, ui_extra_networks, sd_models


class ExtraNetworksPageCheckpoints(ui_extra_networks.ExtraNetworksPage):
    def __init__(self):
        super().__init__('Checkpoints')

    def refresh(self):
        shared.refresh_checkpoints()

    def list_items(self):
        for name, checkpoint1 in sd_models.checkpoints_list.items():
            checkpoint: sd_models.CheckpointInfo = checkpoint1
            path, ext = os.path.splitext(checkpoint.filename)
            previews = [path + ".png", path + ".preview.png"]

            preview = None
            for file in previews:
                if os.path.isfile(file):
                    preview = self.link_preview(file)
                    break

            yield {
                "name": checkpoint.model_name,
                "filename": path,
                "preview": preview,
                "onclick": '"' + html.escape(f"""return selectCheckpoint({json.dumps(name)})""") + '"',
                "local_preview": path + ".png",
            }

    def allowed_directories_for_previews(self):
        return [shared.cmd_opts.ckpt_dir, sd_models.model_path]
Loading