Commit 26d08193 authored by Chris OBryan's avatar Chris OBryan
Browse files

extras: Add option to run upscaling before face fixing

Face restoration can look much better if ran after upscaling, as it
allows the restoration to fix upscaling artifacts. This patch adds
an option to choose which order to run upscaling/face fixing in.
parent 737eb28f
Loading
Loading
Loading
Loading
+95 −50
Original line number Diff line number Diff line
@@ -7,6 +7,10 @@ from PIL import Image
import torch
import tqdm

from typing import Callable, List, Tuple
from functools import partial
from dataclasses import dataclass

from modules import processing, shared, images, devices, sd_models
from modules.shared import opts
import modules.gfpgan_model
@@ -20,7 +24,7 @@ import gradio as gr
cached_images = {}


def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility):
def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_dir, show_extras_results, gfpgan_visibility, codeformer_visibility, codeformer_weight, upscaling_resize, upscaling_resize_w, upscaling_resize_h, upscaling_crop, extras_upscaler_1, extras_upscaler_2, extras_upscaler_2_visibility, upscale_first: bool ):
    devices.torch_gc()

    imageArr = []
@@ -57,15 +61,8 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
        outpath = opts.outdir_samples or opts.outdir_extras_samples


    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''
        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
        info = ""

        if gfpgan_visibility > 0:
    # Extra operation definitions
    def run_gfpgan(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        restored_img = modules.gfpgan_model.gfpgan_fix_faces(np.array(image, dtype=np.uint8))
        res = Image.fromarray(restored_img)

@@ -73,9 +70,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
            res = Image.blend(image, res, gfpgan_visibility)

        info += f"GFPGAN visibility:{round(gfpgan_visibility, 2)}\n"
            image = res
        return (res, info)

        if codeformer_visibility > 0:
    def run_codeformer(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        restored_img = modules.codeformer_model.codeformer.restore(np.array(image, dtype=np.uint8), w=codeformer_weight)
        res = Image.fromarray(restored_img)

@@ -83,14 +80,9 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
            res = Image.blend(image, res, codeformer_visibility)

        info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
            image = res
        return (res, info)

        if resize_mode == 1:
            upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
            crop_info = " (crop)" if upscaling_crop else ""
            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"

        if upscaling_resize != 1.0:
    def upscale(image, scaler_index, resize, mode, resize_w, resize_h, crop):
        small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
        pixels = tuple(np.array(small).flatten().tolist())
@@ -106,18 +98,71 @@ def run_extras(extras_mode, resize_mode, image, image_folder, input_dir, output_
                cropped.paste(c, box=(resize_w // 2 - c.width // 2, resize_h // 2 - c.height // 2))
                c = cropped
            cached_images[key] = c

        return c

            info += f"Upscale: {round(upscaling_resize, 3)}, model:{shared.sd_upscalers[extras_upscaler_1].name}\n"
            res = upscale(image, extras_upscaler_1, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)

    def run_prepare_crop(image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        # Actual crop happens in run_upscalers_blend, this just sets upscaling_resize and adds info text
        nonlocal upscaling_resize
        if resize_mode == 1:
            upscaling_resize = max(upscaling_resize_w/image.width, upscaling_resize_h/image.height)
            crop_info = " (crop)" if upscaling_crop else ""
            info += f"Resize to: {upscaling_resize_w:g}x{upscaling_resize_h:g}{crop_info}\n"
        return (image, info)

    @dataclass
    class UpscaleParams:
        upscaler_idx: int
        blend_alpha: float

    def run_upscalers_blend( params: List[UpscaleParams], image: Image.Image, info: str) -> Tuple[Image.Image, str]:
        blended_result: Image.Image = None
        for upscaler in params:
            res = upscale(image, upscaler.upscaler_idx, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
            info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {upscaler.blend_alpha}, model:{shared.sd_upscalers[upscaler.upscaler_idx].name}\n"
            if blended_result is None:
                blended_result = res
            else:
                blended_result = Image.blend(blended_result, res, upscaler.blend_alpha)
        return (blended_result, info)

    # Build a list of operations to run
    facefix_ops: List[Callable] = []
    if gfpgan_visibility > 0:
        facefix_ops.append(run_gfpgan)
    if codeformer_visibility > 0:
        facefix_ops.append(run_codeformer)

    upscale_ops: List[Callable] = []
    if resize_mode == 1:
        upscale_ops.append(run_prepare_crop)

    if upscaling_resize != 0:
        step_params: List[UpscaleParams] = []
        step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_1, blend_alpha=1.0 ))
        if extras_upscaler_2 != 0 and extras_upscaler_2_visibility > 0:
                res2 = upscale(image, extras_upscaler_2, upscaling_resize, resize_mode, upscaling_resize_w, upscaling_resize_h, upscaling_crop)
                info += f"Upscale: {round(upscaling_resize, 3)}, visibility: {round(extras_upscaler_2_visibility, 3)}, model:{shared.sd_upscalers[extras_upscaler_2].name}\n"
                res = Image.blend(res, res2, extras_upscaler_2_visibility)
            step_params.append( UpscaleParams( upscaler_idx=extras_upscaler_2, blend_alpha=extras_upscaler_2_visibility ) )

        upscale_ops.append( partial(run_upscalers_blend, step_params) )

            image = res

    extras_ops: List[Callable] = []
    if upscale_first:
        extras_ops = upscale_ops + facefix_ops
    else:
        extras_ops = facefix_ops + upscale_ops


    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''
        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
        info = ""
        # Run each operation on each image
        for op in extras_ops:
            image, info = op(image, info)

        while len(cached_images) > 2:
            del cached_images[next(iter(cached_images.keys()))]
+4 −0
Original line number Diff line number Diff line
@@ -1119,6 +1119,9 @@ def create_ui(wrap_gradio_gpu_call):
                    codeformer_visibility = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer visibility", value=0, interactive=modules.codeformer_model.have_codeformer)
                    codeformer_weight = gr.Slider(minimum=0.0, maximum=1.0, step=0.001, label="CodeFormer weight (0 = maximum effect, 1 = minimum effect)", value=0, interactive=modules.codeformer_model.have_codeformer)

                with gr.Group():
                    upscale_before_face_fix = gr.Checkbox(label='Upscale Before Restoring Faces', value=False)

                submit = gr.Button('Generate', elem_id="extras_generate", variant='primary')

            with gr.Column(variant='panel'):
@@ -1152,6 +1155,7 @@ def create_ui(wrap_gradio_gpu_call):
                extras_upscaler_1,
                extras_upscaler_2,
                extras_upscaler_2_visibility,
                upscale_before_face_fix,
            ],
            outputs=[
                result_images,