Commit bf67a5dc authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Upscaler.load_model: don't return None, just use exceptions

parent e3a973a6
Loading
Loading
Loading
Loading
+5 −8
Original line number Diff line number Diff line
@@ -46,16 +46,13 @@ class UpscalerLDSR(Upscaler):

        yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")

        try:
        return LDSR(model, yaml)
        except Exception:
            errors.report("Error importing LDSR", exc_info=True)
        return None

    def do_upscale(self, img, path):
        try:
            ldsr = self.load_model(path)
        if ldsr is None:
            print("NO LDSR!")
        except Exception:
            errors.report(f"Failed loading LDSR model {path}", exc_info=True)
            return img
        ddim_steps = shared.opts.ldsr_steps
        return ldsr.super_resolution(img, ddim_steps, self.scale)
+6 −10
Original line number Diff line number Diff line
import os.path
import sys

import PIL.Image
@@ -8,7 +7,7 @@ from tqdm import tqdm

import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net
from scunet_model_arch import SCUNet

from modules.modelloader import load_file_from_url
from modules.shared import opts
@@ -88,9 +87,10 @@ class UpscalerScuNET(modules.upscaler.Upscaler):

        torch.cuda.empty_cache()

        try:
            model = self.load_model(selected_file)
        if model is None:
            print(f"ScuNET: Unable to load model from {selected_file}", file=sys.stderr)
        except Exception as e:
            print(f"ScuNET: Unable to load model from {selected_file}: {e}", file=sys.stderr)
            return img

        device = devices.get_device_for('scunet')
@@ -123,11 +123,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
        else:
            filename = path
        if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
            print(f"ScuNET: Unable to load model from {filename}", file=sys.stderr)
            return None

        model = net(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
        model = SCUNet(in_nc=3, config=[4, 4, 4, 4, 4, 4, 4], dim=64)
        model.load_state_dict(torch.load(filename), strict=True)
        model.eval()
        for _, v in model.named_parameters():
+20 −20
Original line number Diff line number Diff line
import os
import sys

import numpy as np
import torch
@@ -7,8 +7,8 @@ from tqdm import tqdm

from modules import modelloader, devices, script_callbacks, shared
from modules.shared import opts, state
from swinir_model_arch import SwinIR as net
from swinir_model_arch_v2 import Swin2SR as net2
from swinir_model_arch import SwinIR
from swinir_model_arch_v2 import Swin2SR
from modules.upscaler import Upscaler, UpscalerData


@@ -36,8 +36,10 @@ class UpscalerSwinIR(Upscaler):
        self.scalers = scalers

    def do_upscale(self, img, model_file):
        try:
            model = self.load_model(model_file)
        if model is None:
        except Exception as e:
            print(f"Failed loading SwinIR model {model_file}: {e}", file=sys.stderr)
            return img
        model = model.to(device_swinir, dtype=devices.dtype)
        img = upscale(img, model)
@@ -56,10 +58,8 @@ class UpscalerSwinIR(Upscaler):
            )
        else:
            filename = path
        if filename is None or not os.path.exists(filename):
            return None
        if filename.endswith(".v2.pth"):
            model = net2(
            model = Swin2SR(
                upscale=scale,
                in_chans=3,
                img_size=64,
@@ -74,7 +74,7 @@ class UpscalerSwinIR(Upscaler):
            )
            params = None
        else:
            model = net(
            model = SwinIR(
                upscale=scale,
                in_chans=3,
                img_size=64,
+6 −8
Original line number Diff line number Diff line
import os
import sys

import numpy as np
import torch
@@ -6,9 +6,8 @@ from PIL import Image

import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
from modules.upscaler import Upscaler, UpscalerData
from modules.shared import opts

from modules.upscaler import Upscaler, UpscalerData


def mod2normal(state_dict):
@@ -142,8 +141,10 @@ class UpscalerESRGAN(Upscaler):
            self.scalers.append(scaler_data)

    def do_upscale(self, img, selected_model):
        try:
            model = self.load_model(selected_model)
        if model is None:
        except Exception as e:
            print(f"Unable to load ESRGAN model {selected_model}: {e}", file=sys.stderr)
            return img
        model.to(devices.device_esrgan)
        img = esrgan_upscale(model, img)
@@ -159,9 +160,6 @@ class UpscalerESRGAN(Upscaler):
            )
        else:
            filename = path
        if not os.path.exists(filename) or filename is None:
            print(f"Unable to load {self.model_path} from {filename}")
            return None

        state_dict = torch.load(filename, map_location='cpu' if devices.device_esrgan.type == 'mps' else None)

+15 −18
Original line number Diff line number Diff line
@@ -9,7 +9,6 @@ from modules.shared import cmd_opts, opts
from modules import modelloader, errors



class UpscalerRealESRGAN(Upscaler):
    def __init__(self, path):
        self.name = "RealESRGAN"
@@ -43,9 +42,10 @@ class UpscalerRealESRGAN(Upscaler):
        if not self.enable:
            return img

        try:
            info = self.load_model(path)
        if not os.path.exists(info.local_data_path):
            print(f"Unable to load RealESRGAN model: {info.name}")
        except Exception:
            errors.report(f"Unable to load RealESRGAN model {path}", exc_info=True)
            return img

        upsampler = RealESRGANer(
@@ -63,20 +63,17 @@ class UpscalerRealESRGAN(Upscaler):
        return image

    def load_model(self, path):
        try:
            info = next(iter([scaler for scaler in self.scalers if scaler.data_path == path]), None)

            if info is None:
                print(f"Unable to find model info: {path}")
                return None

            if info.local_data_path.startswith("http"):
                info.local_data_path = modelloader.load_file_from_url(info.data_path, model_dir=self.model_download_path)

            return info
        except Exception:
            errors.report("Error making Real-ESRGAN models list", exc_info=True)
        return None
        for scaler in self.scalers:
            if scaler.data_path == path:
                if scaler.local_data_path.startswith("http"):
                    scaler.local_data_path = modelloader.load_file_from_url(
                        scaler.data_path,
                        model_dir=self.model_download_path,
                    )
                if not os.path.exists(scaler.local_data_path):
                    raise FileNotFoundError(f"RealESRGAN data missing: {scaler.local_data_path}")
                return scaler
        raise ValueError(f"Unable to find model info: {path}")

    def load_models(self, _):
        return get_realesrgan_models(self)