Commit 0dce0df1 authored by d8ahazard's avatar d8ahazard
Browse files

Holy $hit.

Yep.

Fix gfpgan_model_arch requirement(s).
Add Upscaler base class, move from images.
Add a lot of methods to Upscaler.
Re-work all the child upscalers to be proper classes.
Add BSRGAN scaler.
Add ldsr_model_arch class, removing the dependency for another repo that just uses regular latent-diffusion stuff.
Add one universal method that will always find and load new upscaler models without having to add new "setup_model" calls. Still need to add command line params, but that could probably be automated.
Add a "self.scale" property to all Upscalers so the scalers themselves can do "things" in response to the requested upscaling size.
Ensure LDSR doesn't get stuck in a longer loop of "upscale/downscale/upscale" as we try to reach the target upscale size.
Add typehints for IDE sanity.
PEP-8 improvements.
Moar.
parent 31ad536c
Loading
Loading
Loading
Loading
+6 −5
Original line number Diff line number Diff line
# this scripts installs necessary requirements and launches main program in webui.py

import shutil
import subprocess
import os
import sys
@@ -22,7 +22,6 @@ stable_diffusion_commit_hash = os.environ.get('STABLE_DIFFUSION_COMMIT_HASH', "6
taming_transformers_commit_hash = os.environ.get('TAMING_TRANSFORMERS_COMMIT_HASH', "24268930bf1dce879235a7fddd0b2355b84d7ea6")
codeformer_commit_hash = os.environ.get('CODEFORMER_COMMIT_HASH', "c5b4593074ba6214284d6acd5f1719b6c5d739af")
blip_commit_hash = os.environ.get('BLIP_COMMIT_HASH', "48211a1594f1321b00f14c9f7a5b4813144b2fb9")
ldsr_commit_hash = os.environ.get('LDSR_COMMIT_HASH',"abf33e7002d59d9085081bce93ec798dcabd49af")

args = shlex.split(commandline_args)

@@ -122,9 +121,11 @@ git_clone("https://github.com/CompVis/stable-diffusion.git", repo_dir('stable-di
git_clone("https://github.com/CompVis/taming-transformers.git", repo_dir('taming-transformers'), "Taming Transformers", taming_transformers_commit_hash)
git_clone("https://github.com/sczhou/CodeFormer.git", repo_dir('CodeFormer'), "CodeFormer", codeformer_commit_hash)
git_clone("https://github.com/salesforce/BLIP.git", repo_dir('BLIP'), "BLIP", blip_commit_hash)
# Using my repo until my changes are merged, as this makes interfacing with our version of SD-web a lot easier
git_clone("https://github.com/Hafiidz/latent-diffusion", repo_dir('latent-diffusion'), "LDSR", ldsr_commit_hash)

if os.path.isdir(repo_dir('latent-diffusion')):
    try:
        shutil.rmtree(repo_dir('latent-diffusion'))
    except:
        pass
if not is_installed("lpips"):
    run_pip(f"install -r {os.path.join(repo_dir('CodeFormer'), 'requirements.txt')}", "requirements for CodeFormer")

+79 −0
Original line number Diff line number Diff line
import os.path
import sys
import traceback

import PIL.Image
import numpy as np
import torch
from basicsr.utils.download_util import load_file_from_url

import modules.upscaler
from modules import shared, modelloader
from modules.bsrgan_model_arch import RRDBNet
from modules.paths import models_path


class UpscalerBSRGAN(modules.upscaler.Upscaler):
    def __init__(self, dirname):
        self.name = "BSRGAN"
        self.model_path = os.path.join(models_path, self.name)
        self.model_name = "BSRGAN 4x"
        self.model_url = "https://github.com/cszn/KAIR/releases/download/v1.0/BSRGAN.pth"
        self.user_path = dirname
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        scalers = []
        if len(model_paths) == 0:
            scaler_data = modules.upscaler.UpscalerData(self.model_name, self.model_url, self, 4)
            scalers.append(scaler_data)
        for file in model_paths:
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)
            try:
                scaler_data = modules.upscaler.UpscalerData(name, file, self, 4)
                scalers.append(scaler_data)
            except Exception:
                print(f"Error loading BSRGAN model: {file}", file=sys.stderr)
                print(traceback.format_exc(), file=sys.stderr)
        self.scalers = scalers

    def do_upscale(self, img: PIL.Image, selected_file):
        torch.cuda.empty_cache()
        model = self.load_model(selected_file)
        if model is None:
            return img
        model.to(shared.device)
        torch.cuda.empty_cache()
        img = np.array(img)
        img = img[:, :, ::-1]
        img = np.moveaxis(img, 2, 0) / 255
        img = torch.from_numpy(img).float()
        img = img.unsqueeze(0).to(shared.device)
        with torch.no_grad():
            output = model(img)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
        output = 255. * np.moveaxis(output, 0, 2)
        output = output.astype(np.uint8)
        output = output[:, :, ::-1]
        torch.cuda.empty_cache()
        return PIL.Image.fromarray(output, 'RGB')

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
                                          progress=True)
        else:
            filename = path
        if not os.path.exists(filename) or filename is None:
            print("Unable to load %s from %s" % (self.model_dir, filename))
            return None
        print("Loading %s from %s" % (self.model_dir, filename))
        model = RRDBNet(in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=2)  # define network
        model.load_state_dict(torch.load(filename), strict=True)
        model.eval()
        for k, v in model.named_parameters():
            v.requires_grad = False
        return model
+103 −0
Original line number Diff line number Diff line
import functools
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init


def initialize_weights(net_l, scale=1):
    if not isinstance(net_l, list):
        net_l = [net_l]
    for net in net_l:
        for m in net.modules():
            if isinstance(m, nn.Conv2d):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale  # for residual block
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.Linear):
                init.kaiming_normal_(m.weight, a=0, mode='fan_in')
                m.weight.data *= scale
                if m.bias is not None:
                    m.bias.data.zero_()
            elif isinstance(m, nn.BatchNorm2d):
                init.constant_(m.weight, 1)
                init.constant_(m.bias.data, 0.0)


def make_layer(block, n_layers):
    layers = []
    for _ in range(n_layers):
        layers.append(block())
    return nn.Sequential(*layers)


class ResidualDenseBlock_5C(nn.Module):
    def __init__(self, nf=64, gc=32, bias=True):
        super(ResidualDenseBlock_5C, self).__init__()
        # gc: growth channel, i.e. intermediate channels
        self.conv1 = nn.Conv2d(nf, gc, 3, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(nf + gc, gc, 3, 1, 1, bias=bias)
        self.conv3 = nn.Conv2d(nf + 2 * gc, gc, 3, 1, 1, bias=bias)
        self.conv4 = nn.Conv2d(nf + 3 * gc, gc, 3, 1, 1, bias=bias)
        self.conv5 = nn.Conv2d(nf + 4 * gc, nf, 3, 1, 1, bias=bias)
        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

        # initialization
        initialize_weights([self.conv1, self.conv2, self.conv3, self.conv4, self.conv5], 0.1)

    def forward(self, x):
        x1 = self.lrelu(self.conv1(x))
        x2 = self.lrelu(self.conv2(torch.cat((x, x1), 1)))
        x3 = self.lrelu(self.conv3(torch.cat((x, x1, x2), 1)))
        x4 = self.lrelu(self.conv4(torch.cat((x, x1, x2, x3), 1)))
        x5 = self.conv5(torch.cat((x, x1, x2, x3, x4), 1))
        return x5 * 0.2 + x


class RRDB(nn.Module):
    '''Residual in Residual Dense Block'''

    def __init__(self, nf, gc=32):
        super(RRDB, self).__init__()
        self.RDB1 = ResidualDenseBlock_5C(nf, gc)
        self.RDB2 = ResidualDenseBlock_5C(nf, gc)
        self.RDB3 = ResidualDenseBlock_5C(nf, gc)

    def forward(self, x):
        out = self.RDB1(x)
        out = self.RDB2(out)
        out = self.RDB3(out)
        return out * 0.2 + x


class RRDBNet(nn.Module):
    def __init__(self, in_nc=3, out_nc=3, nf=64, nb=23, gc=32, sf=4):
        super(RRDBNet, self).__init__()
        RRDB_block_f = functools.partial(RRDB, nf=nf, gc=gc)
        self.sf = sf
        print([in_nc, out_nc, nf, nb, gc, sf])

        self.conv_first = nn.Conv2d(in_nc, nf, 3, 1, 1, bias=True)
        self.RRDB_trunk = make_layer(RRDB_block_f, nb)
        self.trunk_conv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        #### upsampling
        self.upconv1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        if self.sf==4:
            self.upconv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.HRconv = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.conv_last = nn.Conv2d(nf, out_nc, 3, 1, 1, bias=True)

        self.lrelu = nn.LeakyReLU(negative_slope=0.2, inplace=True)

    def forward(self, x):
        fea = self.conv_first(x)
        trunk = self.trunk_conv(self.RRDB_trunk(fea))
        fea = fea + trunk

        fea = self.lrelu(self.upconv1(F.interpolate(fea, scale_factor=2, mode='nearest')))
        if self.sf==4:
            fea = self.lrelu(self.upconv2(F.interpolate(fea, scale_factor=2, mode='nearest')))
        out = self.conv_last(self.lrelu(self.HRconv(fea)))

        return out
 No newline at end of file
+110 −117
Original line number Diff line number Diff line
import os
import sys
import traceback

import numpy as np
import torch
@@ -8,32 +6,56 @@ from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgam_model_arch as arch
import modules.images
from modules import shared
from modules import shared, modelloader
from modules import shared, modelloader, images
from modules.devices import has_mps
from modules.paths import models_path
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts

model_dir = "ESRGAN"
model_path = os.path.join(models_path, model_dir)
model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
model_name = "ESRGAN_x4"

class UpscalerESRGAN(Upscaler):
    def __init__(self, dirname):
        self.name = "ESRGAN"
        self.model_url = "https://drive.google.com/u/0/uc?id=1TPrz5QKd8DHHt1k8SRtm6tMiPjz_Qene&export=download"
        self.model_name = "ESRGAN 4x"
        self.scalers = []
        self.user_path = dirname
        self.model_path = os.path.join(models_path, self.name)
        super().__init__()
        model_paths = self.find_models(ext_filter=[".pt", ".pth"])
        scalers = []
        if len(model_paths) == 0:
            scaler_data = UpscalerData(self.model_name, self.model_url, self, 4)
            scalers.append(scaler_data)
        for file in model_paths:
            print(f"File: {file}")
            if "http" in file:
                name = self.model_name
            else:
                name = modelloader.friendly_name(file)

def load_model(path: str, name: str):
    global model_path
    global model_url
    global model_dir
    global model_name
            scaler_data = UpscalerData(name, file, self, 4)
            print(f"ESRGAN: Adding scaler {name}")
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        model = self.load_model(selected_model)
        if model is None:
            return img
        model.to(shared.device)
        img = esrgan_upscale(model, img)
        return img

    def load_model(self, path: str):
        if "http" in path:
        filename = load_file_from_url(url=model_url, model_dir=model_path, file_name="%s.pth" % model_name, progress=True)
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path,
                                          file_name="%s.pth" % self.model_name,
                                          progress=True)
        else:
            filename = path
        if not os.path.exists(filename) or filename is None:
        print("Unable to load %s from %s" % (model_dir, filename))
            print("Unable to load %s from %s" % (self.model_path, filename))
            return None
    print("Loading %s from %s" % (model_dir, filename))
        # this code is adapted from https://github.com/xinntao/ESRGAN
        pretrained_net = torch.load(filename, map_location='cpu' if has_mps else None)
        crt_model = arch.RRDBNet(3, 3, 64, 23, gc=32)
@@ -43,7 +65,8 @@ def load_model(path: str, name: str):
            return crt_model

        if 'model.0.weight' not in pretrained_net:
        is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net["params_ema"]
            is_realesrgan = "params_ema" in pretrained_net and 'body.0.rdb1.conv1.weight' in pretrained_net[
                "params_ema"]
            if is_realesrgan:
                raise Exception("The file is a RealESRGAN model, it can't be used as a ESRGAN model.")
            else:
@@ -96,6 +119,7 @@ def load_model(path: str, name: str):
        crt_model.eval()
        return crt_model


def upscale_without_tiling(model, img):
    img = np.array(img)
    img = img[:, :, ::-1]
@@ -115,7 +139,7 @@ def esrgan_upscale(model, img):
    if opts.ESRGAN_tile == 0:
        return upscale_without_tiling(model, img)

    grid = modules.images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
    grid = images.split_grid(img, opts.ESRGAN_tile, opts.ESRGAN_tile, opts.ESRGAN_tile_overlap)
    newtiles = []
    scale_factor = 1

@@ -130,38 +154,7 @@ def esrgan_upscale(model, img):
            newrow.append([x * scale_factor, w * scale_factor, output])
        newtiles.append([y * scale_factor, h * scale_factor, newrow])

    newgrid = modules.images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor, grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
    output = modules.images.combine_grid(newgrid)
    newgrid = images.Grid(newtiles, grid.tile_w * scale_factor, grid.tile_h * scale_factor,
                                  grid.image_w * scale_factor, grid.image_h * scale_factor, grid.overlap * scale_factor)
    output = images.combine_grid(newgrid)
    return output


class UpscalerESRGAN(modules.images.Upscaler):
    def __init__(self, filename, title):
        self.name = title
        self.filename = filename

    def do_upscale(self, img):
        model = load_model(self.filename, self.name)
        if model is None:
            return img
        model.to(shared.device)
        img = esrgan_upscale(model, img)
        return img


def setup_model(dirname):
    global model_path
    global model_name
    if not os.path.exists(model_path):
        os.makedirs(model_path)

    model_paths = modelloader.load_models(model_path, command_path=dirname, ext_filter=[".pt", ".pth"])
    if len(model_paths) == 0:
        modules.shared.sd_upscalers.append(UpscalerESRGAN(model_url, model_name))
    for file in model_paths:
        name = modelloader.friendly_name(file)
        try:
            modules.shared.sd_upscalers.append(UpscalerESRGAN(file, name))
        except Exception:
            print(f"Error loading ESRGAN model: {file}", file=sys.stderr)
            print(traceback.format_exc(), file=sys.stderr)
+17 −18
Original line number Diff line number Diff line
@@ -66,7 +66,6 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
            info += f"CodeFormer w: {round(codeformer_weight, 2)}, CodeFormer visibility:{round(codeformer_visibility, 2)}\n"
            image = res

        if upscaling_resize != 1.0:
        def upscale(image, scaler_index, resize):
            small = image.crop((image.width // 2, image.height // 2, image.width // 2 + 10, image.height // 2 + 10))
            pixels = tuple(np.array(small).flatten().tolist())
@@ -75,7 +74,7 @@ def run_extras(extras_mode, image, image_folder, gfpgan_visibility, codeformer_v
            c = cached_images.get(key)
            if c is None:
                upscaler = shared.sd_upscalers[scaler_index]
                    c = upscaler.upscale(image, image.width * resize, image.height * resize)
                c = upscaler.scaler.upscale(image, resize, upscaler.data_path)
                cached_images[key] = c

            return c
Loading