Commit b5230197 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

rework extras tab to use script system

parent 68303c96
Loading
Loading
Loading
Loading
+0 −5
Original line number Diff line number Diff line
@@ -104,11 +104,6 @@ function create_tab_index_args(tabId, args){
    return res
}

/**
 * Builds the argument array for the extras tab.
 *
 * The first two call arguments are discarded and replaced by the results of
 * get_tab_index('mode_extras') and get_tab_index('extras_resize_mode');
 * every remaining argument is passed through unchanged, in order.
 */
function get_extras_tab_index(){
    const remaining = Array.prototype.slice.call(arguments, 2)
    const modeIndex = get_tab_index('mode_extras')
    const resizeModeIndex = get_tab_index('extras_resize_mode')
    return [modeIndex, resizeModeIndex].concat(remaining)
}

function get_img2img_tab_index() {
    let res = args_to_array(arguments)
    res.splice(-2)
+5 −8
Original line number Diff line number Diff line
@@ -11,10 +11,9 @@ from fastapi.security import HTTPBasic, HTTPBasicCredentials
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing
from modules.api.models import *
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
from modules.extras import run_extras
from modules.textual_inversion.textual_inversion import create_embedding, train_embedding
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
@@ -45,10 +44,8 @@ def validate_sampler_name(name):

def setUpscalers(req: dict):
    """Rename the request's upscaler fields to the keyword names used by ``run_extras``.

    ``upscaler_1`` / ``upscaler_2`` are removed from the request's attribute
    dictionary and stored again as ``extras_upscaler_1`` / ``extras_upscaler_2``.
    A field that is absent yields ``None`` instead of raising ``KeyError``.

    The returned mapping is ``vars(req)`` itself, so the request object is
    modified in place.
    """
    reqDict = vars(req)
    # Each field is popped exactly once; popping it a second time would
    # overwrite the value that was just stored with the ``None`` default.
    reqDict['extras_upscaler_1'] = reqDict.pop('upscaler_1', None)
    reqDict['extras_upscaler_2'] = reqDict.pop('upscaler_2', None)
    return reqDict

def decode_base64_to_image(encoding):
@@ -244,7 +241,7 @@ class Api:
        reqDict['image'] = decode_base64_to_image(reqDict['image'])

        with self.queue_lock:
            result = run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)
            result = postprocessing.run_extras(extras_mode=0, image_folder="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasSingleImageResponse(image=encode_pil_to_base64(result[0][0]), html_info=result[1])

@@ -260,7 +257,7 @@ class Api:
        reqDict.pop('imageList')

        with self.queue_lock:
            result = run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)
            result = postprocessing.run_extras(extras_mode=1, image="", input_dir="", output_dir="", save_output=False, **reqDict)

        return ExtrasBatchImagesResponse(images=list(map(encode_pil_to_base64, result[0])), html_info=result[1])

+60 −176
Original line number Diff line number Diff line
from __future__ import annotations
import os

import numpy as np
from PIL import Image

from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass

from modules import shared, images, devices, ui_components
from modules import shared, images, devices, scripts, scripts_postprocessing, ui_common, generation_parameters_copypaste
from modules.shared import opts
import modules.gfpgan_model
import modules.codeformer_model


class LruCache(OrderedDict):
    """Size-bounded ordered mapping that evicts the least recently used entry.

    Entries are kept in recency order: the first item is the next eviction
    candidate, the last item is the most recently read or written one.
    """

    @dataclass(frozen=True)
    class Key:
        """Hashable cache key made of three precomputed hashes."""
        image_hash: int
        info_hash: int
        args_hash: int

    @dataclass
    class Value:
        """Cached result: an image together with its info text."""
        image: Image.Image
        info: str

    def __init__(self, max_size: int = 5, *args, **kwargs):
        """Create the cache.

        max_size is the number of entries kept after a ``put``; remaining
        positional and keyword arguments are forwarded to ``OrderedDict``.
        """
        super().__init__(*args, **kwargs)
        self._max_size = max_size

    def get(self, key: LruCache.Key) -> LruCache.Value:
        """Return the value stored under key, or None when it is absent.

        A hit marks the entry as most recently used.
        """
        ret = super().get(key)
        if ret is not None:
            self.move_to_end(key)  # Move to end of eviction list
        return ret

    def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
        """Store value under key and evict the oldest entries beyond max_size."""
        self[key] = value
        # Assigning to an existing key keeps its old position in an
        # OrderedDict, so refresh recency explicitly; otherwise an entry that
        # was just rewritten could be the next one evicted.
        self.move_to_end(key)
        while len(self) > self._max_size:
            self.popitem(last=False)


cached_images: LruCache = LruCache(max_size=5)


def run_postprocessing(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
def run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output: bool = True):
    devices.torch_gc()

    shared.state.begin()
    shared.state.job = 'extras'

    imageArr = []
    # Also keep track of original file names
    imageNameArr = []
    image_data = []
    image_names = []
    outputs = []

    if extras_mode == 1:
        #convert file to pillow image
        for img in image_folder:
            image = Image.open(img)
            imageArr.append(image)
            imageNameArr.append(os.path.splitext(img.orig_name)[0])
            image_data.append(image)
            image_names.append(os.path.splitext(img.orig_name)[0])
    elif extras_mode == 2:
        assert not shared.cmd_opts.hide_ui_dir_config, '--hide-ui-dir-config option must be disabled'
        assert input_dir, 'input directory not selected'

        if input_dir == '':
            return outputs, "Please select an input directory.", ''
        image_list = shared.listfiles(input_dir)
        for img in image_list:
        for filename in image_list:
            try:
                image = Image.open(img)
                image = Image.open(filename)
            except Exception:
                continue
            imageArr.append(image)
            imageNameArr.append(img)
            image_data.append(image)
            image_names.append(filename)
    else:
        imageArr.append(image)
        imageNameArr.append(None)
        assert image, 'image not selected'

        image_data.append(image)
        image_names.append(None)

    if extras_mode == 2 and output_dir != '':
        outpath = output_dir
    else:
        outpath = opts.outdir_samples or opts.outdir_extras_samples

    # Extra operation definitions

    def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        shared.state.job = 'extras-gfpgan'
        restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
        res = Image.fromarray(restored_img)

        if gfpgan_visibility < 1.0:
            res = Image.blend(image, res, gfpgan_visibility)

        info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
        return (res, info)

    def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        shared.state.job = 'extras-codeformer'
        restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
        res = Image.fromarray(restored_img)

        if codeformer_visibility < 1.0:
            res = Image.blend(image, res, codeformer_visibility)

        info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
        return (res, info)

    def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
        shared.state.job = 'extras-upscale'
        upscaler = shared.sd_upscalers[scaler_index]
        res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
        if mode == 1 and crop:
            cropped = Image.new("RGB", (resize_w, resize_h))
            cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
            res = cropped
        return res

    def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
        nonlocal upscaling_resize
        if resize_mode == 1:
            upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
            crop_info = " (crop)" if upscaling_crop else ""
            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
        return (image, info)

    @dataclass
    class UpscaleParams:
        upscaler_idx: int
        blend_alpha: float

    def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        blended_result: Image.Image = None
        image_hash: str = hash(np.array(image.getdata()).tobytes())
        for upscaler in params:
            upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
                            upscaling_resize_w, upscaling_resize_h, upscaling_crop)
            cache_key = LruCache.Key(image_hash=image_hash,
                                     info_hash=hash(info),
                                     args_hash=hash(upscale_args))
            cached_entry = cached_images.get(cache_key)
            if cached_entry is None:
                res = upscale(image, *upscale_args)
                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
                cached_images.put(cache_key, LruCache.Value(image=res, info=info))
            else:
                res, info = cached_entry.image, cached_entry.info

            if blended_result is None:
                blended_result = res
            else:
                blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
        return (blended_result, info)

    # Build a list of operations to run
    facefix_ops: List[Callable] = []
    facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
    facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []

    upscale_ops: List[Callable] = []
    upscale_ops += [run_prepare_crop] if resize_mode == 1 else []

    if upscaling_resize != 0:
        step_params: List[UpscaleParams] = []
        step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
        if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
            step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))

        upscale_ops.append(partial(run_upscalers_blend, step_params))

    extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)

    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''
    infotext = ''

        shared.state.textinfo = f'Processing image {image_name}'
    for image, name in zip(image_data, image_names):
        shared.state.textinfo = name

        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
        info = ""
        # Run each operation on each image
        for op in extras_ops:
            image, info = op(image, info)
        pp = scripts_postprocessing.PostprocessedImage(image.convert("RGB"))

        if opts.use_original_name_batch and image_name is not None:
            basename = os.path.splitext(os.path.basename(image_name))[0]
        scripts.scripts_postproc.run(pp, args)

        if opts.use_original_name_batch and name is not None:
            basename = os.path.splitext(os.path.basename(name))[0]
        else:
            basename = ''

        if opts.enable_pnginfo: # append info before save
            image.info = existing_pnginfo
            image.info["extras"] = info
        infotext = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in pp.info.items() if v is not None])

        if save_output:
            # Add upscaler name as a suffix.
            suffix = f"-{shared.sd_upscalers[extras_upscaler_1].name}" if shared.opts.use_upscaler_name_as_suffix else ""
            # Add second upscaler if applicable.
            if suffix and extras_upscaler_2 and extras_upscaler_2_visibility:
                suffix += f"-{shared.sd_upscalers[extras_upscaler_2].name}"
        if opts.enable_pnginfo:
            pp.image.info = existing_pnginfo
            pp.image.info["postprocessing"] = infotext

            images.save_image(image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=info, short_filename=True,
                            no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None, suffix=suffix)
        if save_output:
            images.save_image(pp.image, path=outpath, basename=basename, seed=None, prompt=None, extension=opts.samples_format, info=pp.info, short_filename=True, no_prompt=True, grid=False, pnginfo_section_name="extras", existing_info=existing_pnginfo, forced_filename=None)

        if extras_mode != 2 or show_extras_results:
            outputs.append(image)
            outputs.append(pp.image)

    devices.torch_gc()

    return outputs, ui_components.plaintext_to_html(info), ''


def clear_cache():
    cached_images.clear()
    return outputs, ui_common.plaintext_to_html(infotext), ''


def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool, save_output: bool = True):
    """Adapter that keeps the old flat-argument API working.

    The individual upscaling / GFPGAN / CodeFormer parameters are grouped by
    script name, turned into the positional argument list of the
    postprocessing scripts, and handed to run_postprocessing, whose result is
    returned unchanged.

    upscale_first is accepted for signature compatibility but is not used here.
    """

    upscale_settings = {
        "upscale_mode": resize_mode,
        "upscale_by": upscaling_resize,
        "upscale_to_width": upscaling_resize_w,
        "upscale_to_height": upscaling_resize_h,
        "upscale_crop": upscaling_crop,
        "upscaler_1_name": extras_upscaler_1,
        "upscaler_2_name": extras_upscaler_2,
        "upscaler_2_visibility": extras_upscaler_2_visibility,
    }
    gfpgan_settings = {
        "gfpgan_visibility": gfpgan_visibility,
    }
    codeformer_settings = {
        "codeformer_visibility": codeformer_visibility,
        "codeformer_weight": codeformer_weight,
    }

    # Keys are the script names that create_args_for_run looks up.
    per_script_settings = {
        "Upscale": upscale_settings,
        "GFPGAN": gfpgan_settings,
        "CodeFormer": codeformer_settings,
    }
    args = scripts.scripts_postproc.create_args_for_run(per_script_settings)

    return run_postprocessing(extras_mode, image, image_folder, input_dir, output_dir, show_extras_results, *args, save_output=save_output)
+20 −8
Original line number Diff line number Diff line
@@ -7,7 +7,7 @@ from collections import namedtuple
import gradio as gr

from modules.processing import StableDiffusionProcessing
from modules import shared, paths, script_callbacks, extensions, script_loading
from modules import shared, paths, script_callbacks, extensions, script_loading, scripts_postprocessing

AlwaysVisible = object()

@@ -150,8 +150,10 @@ def basedir():
    return current_basedir


scripts_data = []
ScriptFile = namedtuple("ScriptFile", ["basedir", "filename", "path"])

scripts_data = []
postprocessing_scripts_data = []
ScriptClassData = namedtuple("ScriptClassData", ["script_class", "path", "basedir", "module"])


@@ -190,23 +192,31 @@ def list_files_with_name(filename):
def load_scripts():
    global current_basedir
    scripts_data.clear()
    postprocessing_scripts_data.clear()
    script_callbacks.clear_callbacks()

    scripts_list = list_scripts("scripts", ".py")

    syspath = sys.path

    def register_scripts_from_module(module):
        for key, script_class in module.__dict__.items():
            if type(script_class) != type:
                continue

            if issubclass(script_class, Script):
                scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
            elif issubclass(script_class, scripts_postprocessing.ScriptPostprocessing):
                postprocessing_scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))

    for scriptfile in sorted(scripts_list):
        try:
            if scriptfile.basedir != paths.script_path:
                sys.path = [scriptfile.basedir] + sys.path
            current_basedir = scriptfile.basedir

            module = script_loading.load_module(scriptfile.path)

            for key, script_class in module.__dict__.items():
                if type(script_class) == type and issubclass(script_class, Script):
                    scripts_data.append(ScriptClassData(script_class, scriptfile.path, scriptfile.basedir, module))
            script_module = script_loading.load_module(scriptfile.path)
            register_scripts_from_module(script_module)

        except Exception:
            print(f"Error loading script: {scriptfile.filename}", file=sys.stderr)
@@ -413,6 +423,7 @@ class ScriptRunner:

scripts_txt2img = ScriptRunner()
scripts_img2img = ScriptRunner()
scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()
scripts_current: ScriptRunner = None


@@ -423,12 +434,13 @@ def reload_script_body_only():


def reload_scripts():
    """Reload all script files and replace the module-level runners.

    Calls load_scripts() and then rebinds scripts_txt2img, scripts_img2img and
    scripts_postproc to freshly constructed runner objects.
    """
    # One declaration covering every rebound name; the shorter duplicate
    # declaration was redundant.
    global scripts_txt2img, scripts_img2img, scripts_postproc

    load_scripts()

    scripts_txt2img = ScriptRunner()
    scripts_img2img = ScriptRunner()
    scripts_postproc = scripts_postprocessing.ScriptPostprocessingRunner()


def IOComponent_init(self, *args, **kwargs):
+147 −0
Original line number Diff line number Diff line
import os
import gradio as gr

from modules import errors, shared


class PostprocessedImage:
    """Holds an image being postprocessed plus a dict of info entries about it.

    ``info`` starts out empty and is unique to each instance.
    """

    def __init__(self, image):
        self.info = dict()
        self.image = image


class ScriptPostprocessing:
    """Base class for postprocessing scripts; subclasses override ui() and process()."""

    # Path of the file the script was loaded from (set in initialize_scripts).
    filename = None
    # Value returned by ui() (set in create_script_ui).
    controls = None
    # Start / end offsets of this script's values in the flat argument list
    # (set in create_script_ui).
    args_from = None
    args_to = None

    # Scripts are ordered by this value in the postprocessing UI.
    order = 1000

    # Title of the script; also the key used to look up its arguments.
    name = None

    # Container component that has all of the script's UI inside it (set in setup_ui).
    group = None

    def ui(self):
        """
        This function should create gradio UI elements. See https://gradio.app/docs/#components
        The return value should be a dictionary that maps parameter names to components used in processing.
        Values of those components will be passed to process() function.
        """

    def process(self, pp: PostprocessedImage, **args):
        """
        This function is called to postprocess the image.
        args contains a dictionary with all values returned by components from ui()
        """

    def image_changed(self):
        """Hook invoked by the runner's image_changed(); does nothing by default."""


def wrap_call(func, filename, funcname, *args, default=None, **kwargs):
    """Call func(*args, **kwargs) and return its result.

    If the call raises an Exception, the error is reported through
    errors.display with a "calling <filename>/<funcname>" message and
    default is returned instead.
    """
    try:
        return func(*args, **kwargs)
    except Exception as exc:
        errors.display(exc, f"calling {filename}/{funcname}")
        return default


class ScriptPostprocessingRunner:
    """Instantiates postprocessing scripts, builds their UI and runs them on images."""

    def __init__(self):
        # None until initialize_scripts() is called (done lazily by
        # scripts_in_preferred_order()).
        self.scripts = None
        self.ui_created = False

    def initialize_scripts(self, scripts_data):
        """Instantiate every script class in scripts_data.

        scripts_data is an iterable of (script_class, path, basedir, module)
        tuples; each instance gets its ``filename`` set to ``path``.
        """
        self.scripts = []

        for script_class, path, basedir, script_module in scripts_data:
            script: ScriptPostprocessing = script_class()
            script.filename = path

            self.scripts.append(script)

    def create_script_ui(self, script, inputs):
        """Build the UI of one script and append its components to inputs.

        inputs is extended in place; ``script.args_from`` / ``script.args_to``
        record the slice of inputs that belongs to this script.
        """
        script.args_from = len(inputs)
        script.args_to = len(inputs)

        # ui() returns None by default, and wrap_call returns None when ui()
        # raised; treat both as "no controls" rather than failing on .values().
        script.controls = wrap_call(script.ui, script.filename, "ui") or {}

        for control in script.controls.values():
            control.custom_script_source = os.path.basename(script.filename)

        inputs += list(script.controls.values())
        script.args_to = len(inputs)

    def scripts_in_preferred_order(self):
        """Return the scripts sorted by configured order, then ``order``, then name.

        Scripts are initialized from modules.scripts.postprocessing_scripts_data
        on first use.
        """
        if self.scripts is None:
            import modules.scripts
            self.initialize_scripts(modules.scripts.postprocessing_scripts_data)

        scripts_order = [x.lower().strip() for x in shared.opts.postprocessing_scipts_order.split(",")]

        def script_score(name):
            # Position of the first configured entry that is a substring of
            # the script name; unmatched scripts sort after all matched ones.
            name = name.lower()
            for i, possible_match in enumerate(scripts_order):
                if possible_match in name:
                    return i

            return len(self.scripts)

        script_scores = {script.name: (script_score(script.name), script.order, script.name, original_index) for original_index, script in enumerate(self.scripts)}

        return sorted(self.scripts, key=lambda x: script_scores[x.name])

    def setup_ui(self):
        """Create the UI of every script, each inside its own gr.Box.

        Returns the flat list of all components created by the scripts.
        """
        inputs = []

        for script in self.scripts_in_preferred_order():
            with gr.Box() as group:
                self.create_script_ui(script, inputs)

            script.group = group

        self.ui_created = True
        return inputs

    def run(self, pp: PostprocessedImage, args):
        """Run every script on pp in preferred order.

        args is the flat argument list; each script receives its own slice as
        keyword arguments named after the keys of its controls.
        """
        for script in self.scripts_in_preferred_order():
            shared.state.job = script.name

            script_args = args[script.args_from:script.args_to]

            # Pair control names with their values positionally.
            process_args = dict(zip(script.controls, script_args))

            script.process(pp, **process_args)

    def create_args_for_run(self, scripts_args):
        """Build the flat argument list for run() from per-script dictionaries.

        scripts_args maps a script name to a dict of {control name: value}.
        Controls without a supplied value, and scripts that are not mentioned,
        get None. The UI is created first if that has not happened yet.
        """
        if not self.ui_created:
            with gr.Blocks(analytics_enabled=False):
                self.setup_ui()

        scripts = self.scripts_in_preferred_order()
        # default=0 keeps this working when no postprocessing scripts exist;
        # max() of an empty sequence would raise ValueError.
        args = [None] * max((x.args_to for x in scripts), default=0)

        for script in scripts:
            script_args_dict = scripts_args.get(script.name, None)
            if script_args_dict is not None:

                for i, name in enumerate(script.controls):
                    args[script.args_from + i] = script_args_dict.get(name, None)

        return args

    def image_changed(self):
        """Forward the image-changed notification to every script."""
        for script in self.scripts_in_preferred_order():
            script.image_changed()
Loading