Unverified Commit 1fba573d authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #3874 from cobryan05/extra_tweak

Extras Tab - Option to upscale before face fix, caching improvements
parents 2338ed95 d8b36614
Loading
Loading
Loading
Loading
+121 −55
Original line number Diff line number Diff line
from __future__ import annotations
import math
import os

@@ -7,6 +8,10 @@ from PIL import Image
import torch
import tqdm

from typing import Callable, List, OrderedDict, Tuple
from functools import partial
from dataclasses import dataclass

from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
@@ -17,10 +22,38 @@ import piexif.helper
import gradio as gr


cached_images = {}
class LruCache(OrderedDict):
    @dataclass(frozen=True)
    class Key:
        image_hash: int
        info_hash: int
        args_hash: int

    @dataclass
    class Value:
        image: Image.Image
        info: str

    def __init__(self, max_size: int = 5, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self._max_size = max_size

    def get(self, key: LruCache.Key) -> LruCache.Value:
        ret = super().get(key)
        if ret is not None:
            self.move_to_end(key)  # Move to end of eviction list
        return ret

    def put(self, key: LruCache.Key, value: LruCache.Value) -> None:
        self[key] = value
        while len(self) > self._max_size:
            self.popitem(last=False)


def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
cached_images: LruCache = LruCache(max_size=5)


def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool):
    devices.torch_gc()

    imageArr = []
@@ -56,16 +89,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
    else:
        outpath = opts.outdir_samples or opts.outdir_extras_samples

    # Extra operation definitions

    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''
        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
        info = ""

        if gfpgan_visibility > 0:
    def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
        res = Image.fromarray(restored_img)

@@ -73,9 +99,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
            res = Image.blend(image, res, gfpgan_visibility)

        info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
            image = res
        return (res, info)

        if codeformer_visibility > 0:
    def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
        res = Image.fromarray(restored_img)

@@ -83,44 +109,81 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
            res = Image.blend(image, res, codeformer_visibility)

        info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
            image = res
        return (res, info)

    def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
        upscaler = shared.sd_upscalers[scaler_index]
        res = upscaler.scaler.upscale(image, resize, upscaler.data_path)
        if mode == 1 and crop:
            cropped = Image.new("RGB", (resize_w, resize_h))
            cropped.paste(res, box=(resize_w // 2 - res.width // 2, resize_h // 2 - res.height // 2))
            res = cropped
        return res

    def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
        nonlocal upscaling_resize
        if resize_mode == 1:
            upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
            crop_info = " (crop)" if upscaling_crop else ""
            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
        return (image, info)

    @dataclass
    class UpscaleParams:
        upscaler_idx: int
        blend_alpha: float

    def run_upscalers_blend(params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        blended_result: Image.Image = None
        for upscaler in params:
            upscale_args = (upscaler.upscaler_idx, upscaling_resize, resize_mode,
                            upscaling_resize_w, upscaling_resize_h, upscaling_crop)
            cache_key = LruCache.Key(image_hash=hash(np.array(image.getdata()).tobytes()),
                                     info_hash=hash(info),
                                     args_hash=hash(upscale_args))
            cached_entry = cached_images.get(cache_key)
            if cached_entry is None:
                res = upscale(image, *upscale_args)
                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
                cached_images.put(cache_key, LruCache.Value(image=res, info=info))
            else:
                res, info = cached_entry.image, cached_entry.info

        if upscaling_resize != 1.0:
            def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
                small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
                pixels = tuple(np.array(small).flatten().tolist())
                key = (resize, scaler_index, image.width, image.height, gfpgan_visibility, codeformer_visibility, codeformer_weight, 
                       resize_mode, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop) + pixels

                c = cached_images.get(key)
                if c is None:
                    upscaler = shared.sd_upscalers[scaler_index]
                    c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
                    if mode == 1 and crop:
                        cropped = Image.new("RGB", (resize_w, resize_h))
                        cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
                        c = cropped
                    cached_images[key] = c
            if blended_result is None:
                blended_result = res
            else:
                blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
        return (blended_result, info)

                return c
    # Build a list of operations to run
    facefix_ops: List[Callable] = []
    facefix_ops += [run_gfpgan] if gfpgan_visibility > 0 else []
    facefix_ops += [run_codeformer] if codeformer_visibility > 0 else []

            info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
            res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
    upscale_ops: List[Callable] = []
    upscale_ops += [run_prepare_crop] if resize_mode == 1 else []

    if upscaling_resize != 0:
        step_params: List[UpscaleParams] = []
        step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_1, blend_alpha=1.0))
        if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
                res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
                res = Image.blend(res, res2, extras_upscaler_2_visibility)
            step_params.append(UpscaleParams(upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility))

            image = res
        upscale_ops.append(partial(run_upscalers_blend, step_params))

        while len(cached_images) > 2:
            del cached_images[next(iter(cached_images.keys()))]
    extras_ops: List[Callable] = (upscale_ops + facefix_ops) if upscale_first else (facefix_ops + upscale_ops)

    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''
        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
        info = ""
        # Run each operation on each image
        for op in extras_ops:
            image, info = op(image, info)

        if opts.use_original_name_batch and image_name != None:
            basename = os.path.splitext(os.path.basename(image_name))[0]
@@ -141,6 +204,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_

    return outputs, plaintext_to_html(info), ''

def clear_cache():
    cached_images.clear()


def run_pnginfo(image):
    if image is None:
+9 −0
Original line number Diff line number Diff line
@@ -1119,6 +1119,9 @@ def create_ui(wrap_gradio_gpu_call):
                    codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
                    codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)

                with gr.Group():
                    upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)

                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')

            with gr.Column(variant='panel'):
@@ -1152,6 +1155,7 @@ def create_ui(wrap_gradio_gpu_call):
                extras_upscaler_1,
                extras_upscaler_2,
                extras_upscaler_2_visibility,
                upscale_before_face_fix,
            ],
            outputs=[
                result_images,
@@ -1174,6 +1178,11 @@ def create_ui(wrap_gradio_gpu_call):
            outputs=[init_img_with_mask],
        )

        extras_image.change(
            fn=modules.extras.clear_cache,
            inputs=[], outputs=[]
        )

    with gr.Blocks(analytics_enabled=False) as pnginfo_interface:
        with gr.Row().style(equal_height=False):
            with gr.Column(variant='panel'):