Commit b0063827 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

serve images from where they are saved instead of a temporary directory

add an option to choose a different temporary directory in the UI
add an option to cleanup the selected temporary directory at startup
parent b5050ad2
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -524,6 +524,8 @@ def save_image(image, path, basename, seed=None, prompt=None, extension='png', i
    else:
        image.save(fullfn, quality=opts.jpeg_quality)

    image.already_saved_as = fullfn

    target_side_length = 4000
    oversize = image.width > target_side_length or image.height > target_side_length
    if opts.export_for_4chan and (oversize or os.stat(fullfn).st_size > 4 * 1024 * 1024):
+7 −0
Original line number Diff line number Diff line
@@ -16,6 +16,9 @@ import modules.devices as devices
from modules import localization, sd_vae, extensions, script_loading
from modules.paths import models_path, script_path, sd_path


demo = None

sd_model_file = os.path.join(script_path, 'model.ckpt')
default_sd_model_file = sd_model_file
parser = argparse.ArgumentParser()
@@ -292,6 +295,10 @@ options_templates.update(options_section(('saving-images', "Saving images/grids"
    "use_original_name_batch": OptionInfo(False, "Use original name for output filename during batch process in extras tab"),
    "save_selected_only": OptionInfo(True, "When using 'Save' button, only save a single selected image"),
    "do_not_add_watermark": OptionInfo(False, "Do not add watermark to images"),

    "temp_dir":  OptionInfo("", "Directory for temporary images; leave empty for default"),
    "clean_temp_dir_at_start": OptionInfo(False, "Cleanup non-default temporary directory when starting webui"),

}))

options_templates.update(options_section(('saving-paths', "Paths for saving"), {
+0 −16
Original line number Diff line number Diff line
@@ -157,22 +157,6 @@ def save_files(js_data, images, do_make_zip, index):

    return gr.File.update(value=fullfns, visible=True), '', '', plaintext_to_html(f"Saved: {filenames[0]}")

def save_pil_to_file(pil_image, dir=None):
    use_metadata = False
    metadata = PngImagePlugin.PngInfo()
    for key, value in pil_image.info.items():
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)
            use_metadata = True

    file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
    pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
    return file_obj


# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file


def wrap_gradio_call(func, extra_outputs=None, add_stats=False):
    def f(*args, extra_outputs_array=extra_outputs, **kwargs):

modules/ui_tempdir.py

0 → 100644
+62 −0
Original line number Diff line number Diff line
import os
import tempfile
from collections import namedtuple

import gradio as gr

from PIL import PngImagePlugin

from modules import shared


Savedfile = namedtuple("Savedfile", ["name"])


def save_pil_to_file(pil_image, dir=None):
    already_saved_as = getattr(pil_image, 'already_saved_as', None)
    if already_saved_as:
        shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(os.path.dirname(already_saved_as))}
        file_obj = Savedfile(already_saved_as)
        return file_obj

    if shared.opts.temp_dir != "":
        dir = shared.opts.temp_dir

    use_metadata = False
    metadata = PngImagePlugin.PngInfo()
    for key, value in pil_image.info.items():
        if isinstance(key, str) and isinstance(value, str):
            metadata.add_text(key, value)
            use_metadata = True

    file_obj = tempfile.NamedTemporaryFile(delete=False, suffix=".png", dir=dir)
    pil_image.save(file_obj, pnginfo=(metadata if use_metadata else None))
    return file_obj


# override save to file function so that it also writes PNG info
gr.processing_utils.save_pil_to_file = save_pil_to_file


def on_tmpdir_changed():
    if shared.opts.temp_dir == "" or shared.demo is None:
        return

    os.makedirs(shared.opts.temp_dir, exist_ok=True)

    shared.demo.temp_dirs = shared.demo.temp_dirs | {os.path.abspath(shared.opts.temp_dir)}


def cleanup_tmpdr():
    temp_dir = shared.opts.temp_dir
    if temp_dir == "" or not os.path.isdir(temp_dir):
        return

    for root, dirs, files in os.walk(temp_dir, topdown=False):
        for name in files:
            _, extension = os.path.splitext(name)
            if extension != ".png":
                continue

            filename = os.path.join(root, name)
            os.remove(filename)
+11 −5
Original line number Diff line number Diff line
@@ -10,7 +10,7 @@ from fastapi.middleware.gzip import GZipMiddleware

from modules.paths import script_path

from modules import shared, devices, sd_samplers, upscaler, extensions, localization
from modules import shared, devices, sd_samplers, upscaler, extensions, localization, ui_tempdir
import modules.codeformer_model as codeformer
import modules.extras
import modules.face_restoration
@@ -31,12 +31,14 @@ from modules import modelloader
from modules.shared import cmd_opts
import modules.hypernetworks.hypernetwork


queue_lock = threading.Lock()
if cmd_opts.server_name:
    server_name = cmd_opts.server_name
else:
    server_name = "0.0.0.0" if cmd_opts.listen else None


def wrap_queued_call(func):
    def f(*args, **kwargs):
        with queue_lock:
@@ -87,6 +89,7 @@ def initialize():
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_hypernetwork", wrap_queued_call(lambda: shared.reload_hypernetworks()))
    shared.opts.onchange("sd_hypernetwork_strength", modules.hypernetworks.hypernetwork.apply_strength)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)

    if cmd_opts.tls_keyfile is not None and cmd_opts.tls_keyfile is not None:

@@ -149,9 +152,12 @@ def webui():
    initialize()

    while 1:
        demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)
        if shared.opts.clean_temp_dir_at_start:
            ui_tempdir.cleanup_tmpdr()

        shared.demo = modules.ui.create_ui(wrap_gradio_gpu_call=wrap_gradio_gpu_call)

        app, local_url, share_url = demo.launch(
        app, local_url, share_url = shared.demo.launch(
            share=cmd_opts.share,
            server_name=server_name,
            server_port=cmd_opts.port,
@@ -178,9 +184,9 @@ def webui():
        if launch_api:
            create_api(app)

        modules.script_callbacks.app_started_callback(demo, app)
        modules.script_callbacks.app_started_callback(shared.demo, app)

        wait_on_server(demo)
        wait_on_server(shared.demo)

        sd_samplers.set_samplers()