Commit 89352a2f authored by Aarni Koskela's avatar Aarni Koskela
Browse files

Move `load_file_from_url` to modelloader

parent 59419bd6
Loading
Loading
Loading
Loading
+3 −4
Original line number Diff line number Diff line
import os

from basicsr.utils.download_util import load_file_from_url

from modules.modelloader import load_file_from_url
from modules.upscaler import Upscaler, UpscalerData
from ldsr_model_arch import LDSR
from modules import shared, script_callbacks, errors
@@ -43,9 +42,9 @@ class UpscalerLDSR(Upscaler):
        if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
            model = local_safetensors_path
        else:
            model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)
            model = local_ckpt_path or load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name="model.ckpt")

        yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)
        yaml = local_yaml_path or load_file_from_url(self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml")

        try:
            return LDSR(model, yaml)
+2 −3
Original line number Diff line number Diff line
@@ -6,12 +6,11 @@ import numpy as np
import torch
from tqdm import tqdm

from basicsr.utils.download_util import load_file_from_url

import modules.upscaler
from modules import devices, modelloader, script_callbacks, errors
from scunet_model_arch import SCUNet as net

from modules.modelloader import load_file_from_url
from modules.shared import opts


@@ -120,7 +119,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
    def load_model(self, path: str):
        device = devices.get_device_for('scunet')
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
            filename = load_file_from_url(self.model_url, model_dir=self.model_download_path, file_name=f"{self.name}.pth")
        else:
            filename = path
        if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
+5 −3
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ import os
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url
from tqdm import tqdm

from modules import modelloader, devices, script_callbacks, shared
@@ -50,8 +49,11 @@ class UpscalerSwinIR(Upscaler):

    def load_model(self, path, scale=4):
        if "http" in path:
            dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
            filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
            filename = modelloader.load_file_from_url(
                url=path,
                model_dir=self.model_download_path,
                file_name=f"{self.model_name.replace(' ', '_')}.pth",
            )
        else:
            filename = path
        if filename is None or not os.path.exists(filename):
+1 −3
Original line number Diff line number Diff line
@@ -3,7 +3,6 @@ import os
import numpy as np
import torch
from PIL import Image
from basicsr.utils.download_util import load_file_from_url

import modules.esrgan_model_arch as arch
from modules import modelloader, images, devices
@@ -152,11 +151,10 @@ class UpscalerESRGAN(Upscaler):

    def load_model(self, path: str):
        if "http" in path:
            filename = load_file_from_url(
            filename = modelloader.load_file_from_url(
                url=self.model_url,
                model_dir=self.model_download_path,
                file_name=f"{self.model_name}.pth",
                progress=True,
            )
        else:
            filename = path
+26 −3
Original line number Diff line number Diff line
from __future__ import annotations

import os
import shutil
import importlib
@@ -8,6 +10,29 @@ from modules.upscaler import Upscaler, UpscalerLanczos, UpscalerNearest, Upscale
from modules.paths import script_path, models_path


def load_file_from_url(
    url: str,
    *,
    model_dir: str,
    progress: bool = True,
    file_name: str | None = None,
) -> str:
    """Download a file from `url` into `model_dir`, using the file present if possible.

    Returns the path to the downloaded file.
    """
    os.makedirs(model_dir, exist_ok=True)
    if not file_name:
        parts = urlparse(url)
        file_name = os.path.basename(parts.path)
    cached_file = os.path.abspath(os.path.join(model_dir, file_name))
    if not os.path.exists(cached_file):
        print(f'Downloading: "{url}" to {cached_file}\n')
        from torch.hub import download_url_to_file
        download_url_to_file(url, cached_file, progress=progress)
    return cached_file


def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
    """
    A one-and done loader to try finding the desired models in specified directories.
@@ -46,9 +71,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None

        if model_url is not None and len(output) == 0:
            if download_name is not None:
                from basicsr.utils.download_util import load_file_from_url
                dl = load_file_from_url(model_url, places[0], True, download_name)
                output.append(dl)
                output.append(load_file_from_url(model_url, model_dir=places[0], file_name=download_name))
            else:
                output.append(model_url)

Loading