Commit 740070ea authored by d8ahazard's avatar d8ahazard
Browse files

Re-implement universal model loading

parent bfb7f15d
Loading
Loading
Loading
Loading
+25 −10
Original line number Diff line number Diff line
@@ -5,22 +5,28 @@ import traceback
import cv2
import torch

from modules import shared, devices
from modules.paths import script_path
from modules import shared, devices, modelloader
from modules.paths import script_path, models_path
import modules.shared
import modules.face_restoration
from importlib import reload

# codeformer people made a choice to include modified basicsr librry to their projectwhich makes
# it utterly impossiblr to use it alongside with other libraries that also use basicsr, like GFPGAN.
# codeformer people made a choice to include modified basicsr library to their project, which makes
# it utterly impossible to use it alongside other libraries that also use basicsr, like GFPGAN.
# I am making a choice to include some files from codeformer to work around this issue.

pretrain_model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'
model_dir = "Codeformer"
model_path = os.path.join(models_path, model_dir)
model_url = 'https://github.com/sczhou/CodeFormer/releases/download/v0.1.0/codeformer.pth'

have_codeformer = False
codeformer = None

def setup_codeformer():

def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    path = modules.paths.paths.get("CodeFormer", None)
    if path is None:
        return
@@ -44,16 +50,22 @@ def setup_codeformer():
            def name(self):
                return "CodeFormer"

            def __init__(self):
            def __init__(self, dirname):
                self.net = None
                self.face_helper = None
                self.cmd_dir = dirname

            def create_models(self):

                if self.net is not None and self.face_helper is not None:
                    self.net.to(devices.device_codeformer)
                    return self.net, self.face_helper

                model_paths = modelloader.load_models(model_path, model_url, self.cmd_dir)
                if len(model_paths) != 0:
                    ckpt_path = model_paths[0]
                else:
                    print("Unable to load codeformer model.")
                    return None, None
                net = net_class(dim_embd=512, codebook_size=1024, n_head=8, n_layers=9, connect_list=['32', '64', '128', '256']).to(devices.device_codeformer)
                ckpt_path = load_file_from_url(url=pretrain_model_url, model_dir=os.path.join(path, 'weights/CodeFormer'), progress=True)
                checkpoint = torch.load(ckpt_path)['params_ema']
@@ -74,6 +86,9 @@ def setup_codeformer():
                original_resolution = np_image.shape[0:2]

                self.create_models()
                if self.net is None or self.face_helper is None:
                    return np_image

                self.face_helper.clean_all()
                self.face_helper.read_image(np_image)
                self.face_helper.get_face_landmarks_5(only_center_face=False, resize=640, eye_dist_threshold=5)
@@ -114,7 +129,7 @@ def setup_codeformer():
        have_codeformer = True

        global codeformer
        codeformer = FaceRestorerCodeFormer()
        codeformer = FaceRestorerCodeFormer(dirname)
        shared.face_restorers.append(codeformer)

    except Exception:
+41 −15
Original line number Diff line number Diff line
@@ -5,15 +5,35 @@ import traceback
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgam_model_arch as arch
import modules.images
from modules import shared
from modules.shared import opts
from modules import shared, modelloader
from modules.devices import has_mps
import modules.images
from modules.paths import models_path
from modules.shared import opts

model_dir = "ESRGAN"
model_path = os.path.join(models_path, model_dir)
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
model_name = "ESRGAN_x4.pth"

def load_model(filename):

def load_model(path: str, name: str):
    global model_path
    global model_url
    global model_dir
    global model_name
    if "http" in path:
        filename = load_file_from_url(url=model_url, model_dir=model_path, file_name=model_name, progress=True)
    else:
        filename = path
    if not os.path.exists(filename) or filename is None:
        print("Unable to load %s from %s" % (model_dir, filename))
        return None
    print("Loading %s from %s" % (model_dir, filename))
    # this code is adapted from https://github.com/xinntao/ESRGAN
    pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
    crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
@@ -118,24 +138,30 @@ def esrgan_upscale(model, img):
class UpscalerESRGAN(modules.images.Upscaler):
    def __init__(self, filename, title):
        self.name = title
        self.model = load_model(filename)
        self.filename = filename

    def do_upscale(self, img):
        model = self.model.to(shared.device)
        model = load_model(self.filename, self.name)
        if model is None:
            return img
        model.to(shared.device)
        img = esrgan_upscale(model, img)
        return img


def load_models(dirname):
    for file in os.listdir(dirname):
        path = os.path.join(dirname, file)
        model_name, extension = os.path.splitext(file)

        if extension != '.pt' and extension != '.pth':
            continue
def setup_model(dirname):
    global model_path
    global model_name
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
    if len(model_paths) == 0:
        modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
    for file in model_paths:
        name = modelloader.friendly_name(file)
        try:
            modules.shared.sd_upscalers.append(UpscalerESRGAN(path, model_name))
            modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
        except Exception:
            print(f"Error loading ESRGAN model: {path}", file=sys.stderr)
            print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
+2 −0
Original line number Diff line number Diff line
@@ -36,6 +36,8 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v

    outputs = []
    for image, image_name in zip(imageArr, imageNameArr):
        if image is None:
            return outputs, "Please select an input image.", ''
        existing_pnginfo = image.info or {}

        image = image.convert("RGB")
+29 −31
Original line number Diff line number Diff line
@@ -7,33 +7,20 @@ from modules import shared, devices
from modules.shared import cmd_opts
from modules.paths import script_path
import modules.face_restoration
from modules import shared, devices, modelloader
from modules.paths import models_path


def gfpgan_model_path():
    from modules.shared import cmd_opts

    filemask = 'GFPGAN*.pth'

    if cmd_opts.gfpgan_model is not None:
        return cmd_opts.gfpgan_model

    places = [script_path, '.', os.path.join(cmd_opts.gfpgan_dir, 'experiments/pretrained_models')]

    filename = None
    for place in places:
        filename = next(iter(glob(os.path.join(place, filemask))), None)
        if filename is not None:
            break

    return filename

model_dir = "GFPGAN"
cmd_dir = None
model_path = os.path.join(models_path, model_dir)
model_url = "https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth"

loaded_gfpgan_model = None


def gfpgan():
    global loaded_gfpgan_model

    global model_path
    if loaded_gfpgan_model is not None:
        loaded_gfpgan_model.gfpgan.to(shared.device)
        return loaded_gfpgan_model
@@ -41,7 +28,15 @@ def gfpgan():
    if gfpgan_constructor is None:
        return None

    model = gfpgan_constructor(model_path=gfpgan_model_path() or 'https://github.com/TencentARC/GFPGAN/releases/download/v1.3.0/GFPGANv1.4.pth', upscale=1, arch='clean', channel_multiplier=2, bg_upsampler=None)
    models = modelloader.load_models(model_path, model_url, cmd_dir)
    if len(models) != 0:
        latest_file = max(models, key=os.path.getctime)
        model_file = latest_file
    else:
        print("Unable to load gfpgan model!")
        return None
    model = gfpgan_constructor(model_path=model_file, model_dir=model_path, upscale=1, arch='clean', channel_multiplier=2,
                               bg_upsampler=None)
    model.gfpgan.to(shared.device)
    loaded_gfpgan_model = model

@@ -50,7 +45,8 @@ def gfpgan():

def gfpgan_fix_faces(np_image):
    model = gfpgan()

    if model is None:
        return np_image
    np_image_bgr = np_image[:, :, ::-1]
    cropped_faces, restored_faces, gfpgan_output_bgr = model.enhance(np_image_bgr, has_aligned=False, only_center_face=False, paste_back=True)
    np_image = gfpgan_output_bgr[:, :, ::-1]
@@ -64,19 +60,21 @@ def gfpgan_fix_faces(np_image):
have_gfpgan = False
gfpgan_constructor = None

def setup_gfpgan():
    try:
        gfpgan_model_path()

        if os.path.exists(cmd_opts.gfpgan_dir):
            sys.path.append(os.path.abspath(cmd_opts.gfpgan_dir))
        from gfpgan import GFPGANer
def setup_model(dirname):
    global model_path
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    try:
        from modules.gfpgan_model_arch import GFPGANerr
        global cmd_dir
        global have_gfpgan
        have_gfpgan = True

        global gfpgan_constructor
        gfpgan_constructor = GFPGANer

        cmd_dir = dirname
        have_gfpgan = True
        gfpgan_constructor = GFPGANerr

        class FaceRestorerGFPGAN(modules.face_restoration.FaceRestoration):
            def name(self):
+150 −0
Original line number Diff line number Diff line
# GFPGAN's own helper downloads model files to locations the caller cannot choose. This is a copy of the
# original helper, changed so that every download goes into a caller-supplied model directory.

import cv2
import os
import torch
from basicsr.utils import img2tensor, tensor2img
from basicsr.utils.download_util import load_file_from_url
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
from torchvision.transforms.functional import normalize

from gfpgan.archs.gfpgan_bilinear_arch import GFPGANBilinear
from gfpgan.archs.gfpganv1_arch import GFPGANv1
from gfpgan.archs.gfpganv1_clean_arch import GFPGANv1Clean

ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))


class GFPGANerr():
    """Helper for restoration with GFPGAN.

    It detects and crops faces, then resizes the faces to 512x512.
    GFPGAN is used to restore the resized faces.
    The background is upsampled with the bg_upsampler.
    Finally, the faces are pasted back onto the upsampled background image.

    Args:
        model_path (str): The path to the GFPGAN model. It can be an http(s) URL, in which case the file is
            first downloaded into ``model_dir``.
        model_dir (str): Directory that receives downloaded weights. It is also handed to
            ``FaceRestoreHelper`` as ``model_rootpath``.
        upscale (float): The upscale of the final output. Default: 2.
        arch (str): The GFPGAN architecture. Option: clean | bilinear | original | RestoreFormer. Default: clean.
        channel_multiplier (int): Channel multiplier for large networks of StyleGAN2. Default: 2.
        bg_upsampler (nn.Module): The upsampler for the background. Default: None.
        device (torch.device): Device to run on. Default: None, meaning CUDA when available, otherwise CPU.

    Raises:
        ValueError: If ``arch`` is not one of the supported architecture names.
    """

    # Architecture names accepted by __init__.
    _SUPPORTED_ARCHS = ('clean', 'bilinear', 'original', 'RestoreFormer')

    def __init__(self, model_path, model_dir, upscale=2, arch='clean', channel_multiplier=2, bg_upsampler=None, device=None):
        # Reject unknown architectures up front; otherwise self.gfpgan would never be assigned and the
        # failure would only surface later as an unrelated AttributeError.
        if arch not in self._SUPPORTED_ARCHS:
            raise ValueError(
                f"Unsupported GFPGAN arch {arch!r}; expected one of: {', '.join(self._SUPPORTED_ARCHS)}")

        self.upscale = upscale
        self.bg_upsampler = bg_upsampler

        # initialize model
        self.device = torch.device('cuda' if torch.cuda.is_available() else 'cpu') if device is None else device

        # initialize the GFP-GAN
        if arch == 'RestoreFormer':
            # Imported lazily so the module is only required when this architecture is requested.
            from gfpgan.archs.restoreformer_arch import RestoreFormer
            self.gfpgan = RestoreFormer()
        else:
            net_classes = {
                'clean': GFPGANv1Clean,
                'bilinear': GFPGANBilinear,
                'original': GFPGANv1,
            }
            # The three variants take the same arguments; only 'original' fixes the decoder.
            self.gfpgan = net_classes[arch](
                out_size=512,
                num_style_feat=512,
                channel_multiplier=channel_multiplier,
                decoder_load_path=None,
                fix_decoder=(arch == 'original'),
                num_mlp=8,
                input_is_latent=True,
                different_w=True,
                narrow=1,
                sft_half=True)

        # initialize face helper
        self.face_helper = FaceRestoreHelper(
            upscale,
            face_size=512,
            crop_ratio=(1, 1),
            det_model='retinaface_resnet50',
            save_ext='png',
            use_parse=True,
            device=self.device,
            model_rootpath=model_dir)

        if model_path.startswith(('https://', 'http://')):
            model_path = load_file_from_url(
                url=model_path, model_dir=model_dir, progress=True, file_name=None)

        # Load onto the CPU first so a checkpoint saved on a CUDA device can still be read on a machine
        # without CUDA; the network is moved to self.device below.
        # NOTE(review): torch.load unpickles the file, and the file may have just been downloaded from a
        # URL - only use checkpoints from trusted sources.
        loadnet = torch.load(model_path, map_location='cpu')
        keyname = 'params_ema' if 'params_ema' in loadnet else 'params'
        self.gfpgan.load_state_dict(loadnet[keyname], strict=True)
        self.gfpgan.eval()
        self.gfpgan = self.gfpgan.to(self.device)

    @torch.no_grad()
    def enhance(self, img, has_aligned=False, only_center_face=False, paste_back=True, weight=0.5):
        """Restore the faces found in ``img``.

        Args:
            img: Input image; it is passed to ``cv2.resize`` or to the face helper's ``read_image``.
            has_aligned (bool): If True, ``img`` is treated as one already-aligned face and is resized to
                512x512 instead of running face detection. Default: False.
            only_center_face (bool): Forwarded to the face helper's landmark detection. Default: False.
            paste_back (bool): If True (and ``has_aligned`` is False), paste the restored faces back into
                the image. Default: True.
            weight (float): Forwarded to the network call as ``weight``. Default: 0.5.

        Returns:
            tuple: ``(cropped_faces, restored_faces, restored_img)``. ``restored_img`` is None when
            ``has_aligned`` is True or ``paste_back`` is False.
        """
        self.face_helper.clean_all()

        if has_aligned:  # the inputs are already aligned
            img = cv2.resize(img, (512, 512))
            self.face_helper.cropped_faces = [img]
        else:
            self.face_helper.read_image(img)
            # get face landmarks for each face
            self.face_helper.get_face_landmarks_5(only_center_face=only_center_face, eye_dist_threshold=5)
            # eye_dist_threshold=5: skip faces whose eye distance is smaller than 5 pixels
            # TODO: even with eye_dist_threshold, it will still introduce wrong detections and restorations.
            # align and warp each face
            self.face_helper.align_warp_face()

        # face restoration
        for cropped_face in self.face_helper.cropped_faces:
            # prepare data
            cropped_face_t = img2tensor(cropped_face / 255., bgr2rgb=True, float32=True)
            normalize(cropped_face_t, (0.5, 0.5, 0.5), (0.5, 0.5, 0.5), inplace=True)
            cropped_face_t = cropped_face_t.unsqueeze(0).to(self.device)

            try:
                output = self.gfpgan(cropped_face_t, return_rgb=False, weight=weight)[0]
                # convert to image
                restored_face = tensor2img(output.squeeze(0), rgb2bgr=True, min_max=(-1, 1))
            except RuntimeError as error:
                # Best effort: on an inference failure keep the unrestored crop so processing continues.
                print(f'\tFailed inference for GFPGAN: {error}.')
                restored_face = cropped_face

            restored_face = restored_face.astype('uint8')
            self.face_helper.add_restored_face(restored_face)

        if has_aligned or not paste_back:
            return self.face_helper.cropped_faces, self.face_helper.restored_faces, None

        # upsample the background
        if self.bg_upsampler is not None:
            # Now only support RealESRGAN for upsampling background
            bg_img = self.bg_upsampler.enhance(img, outscale=self.upscale)[0]
        else:
            bg_img = None

        self.face_helper.get_inverse_affine(None)
        # paste each restored face to the input image
        restored_img = self.face_helper.paste_faces_to_input_image(upsample_img=bg_img)
        return self.face_helper.cropped_faces, self.face_helper.restored_faces, restored_img
Loading