Commit df6fffb0 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

change upscalers to download models into user-specified directory (from...

change upscalers to download models into user-specified directory (from commandline args) rather than the default models/<...>
parent 379fd620
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -45,9 +45,9 @@ class UpscalerLDSR(Upscaler):
        if local_safetensors_path is not None and os.path.exists(local_safetensors_path):
            model = local_safetensors_path
        else:
            model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="model.ckpt", progress=True)
            model = local_ckpt_path if local_ckpt_path is not None else load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="model.ckpt", progress=True)

        yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_path, file_name="project.yaml", progress=True)
        yaml = local_yaml_path if local_yaml_path is not None else load_file_from_url(url=self.yaml_url, model_dir=self.model_download_path, file_name="project.yaml", progress=True)

        try:
            return LDSR(model, yaml)
+1 −2
Original line number Diff line number Diff line
@@ -121,8 +121,7 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
    def load_model(self, path: str):
        device = devices.get_device_for('scunet')
        if "http" in path:
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_path, file_name="%s.pth" % self.name,
                                          progress=True)
            filename = load_file_from_url(url=self.model_url, model_dir=self.model_download_path, file_name="%s.pth" % self.name, progress=True)
        else:
            filename = path
        if not os.path.exists(os.path.join(self.model_path, filename)) or filename is None:
+1 −1
Original line number Diff line number Diff line
@@ -51,7 +51,7 @@ class UpscalerSwinIR(Upscaler):
    def load_model(self, path, scale=4):
        if "http" in path:
            dl_name = "%s%s" % (self.model_name.replace(" ", "_"), ".pth")
            filename = load_file_from_url(url=path, model_dir=self.model_path, file_name=dl_name, progress=True)
            filename = load_file_from_url(url=path, model_dir=self.model_download_path, file_name=dl_name, progress=True)
        else:
            filename = path
        if filename is None or not os.path.exists(filename):
+1 −1
Original line number Diff line number Diff line
@@ -154,7 +154,7 @@ class UpscalerESRGAN(Upscaler):
        if "http" in path:
            filename = load_file_from_url(
                url=self.model_url,
                model_dir=self.model_path,
                model_dir=self.model_download_path,
                file_name=f"{self.model_name}.pth",
                progress=True,
            )
+5 −2
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
        if model_url is not None and len(output) == 0:
            if download_name is not None:
                from basicsr.utils.download_util import load_file_from_url
                dl = load_file_from_url(model_url, model_path, True, download_name)
                dl = load_file_from_url(model_url, places[0], True, download_name)
                output.append(dl)
            else:
                output.append(model_url)
@@ -144,7 +144,10 @@ def load_upscalers():
    for cls in reversed(used_classes.values()):
        name = cls.__name__
        cmd_name = f"{name.lower().replace('upscaler', '')}_models_path"
        scaler = cls(commandline_options.get(cmd_name, None))
        commandline_model_path = commandline_options.get(cmd_name, None)
        scaler = cls(commandline_model_path)
        scaler.user_path = commandline_model_path
        scaler.model_download_path = commandline_model_path or scaler.model_path
        datas += scaler.scalers

    shared.sd_upscalers = sorted(
Loading