Commit ce3f639e authored by AUTOMATIC

add more stuff to ignore when creating model from config

prevent .vae.safetensors files from being listed as stable diffusion models
parent 0c3feb20
modules/modelloader.py +3 −1
@@ -10,7 +10,7 @@ from modules.upscaler import Upscaler
from modules.paths import script_path, models_path


-def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None) -> list:
+def load_models(model_path: str, model_url: str = None, command_path: str = None, ext_filter=None, download_name=None, ext_blacklist=None) -> list:
    """
    A one-and-done loader that tries to find the desired models in the specified directories.

@@ -45,6 +45,8 @@ def load_models(model_path: str, model_url: str = None, command_path: str = None
                    full_path = file
                    if os.path.isdir(full_path):
                        continue
+                    if ext_blacklist is not None and any(full_path.endswith(x) for x in ext_blacklist):
+                        continue
                    if len(ext_filter) != 0:
                        model_name, extension = os.path.splitext(file)
                        if extension not in ext_filter:
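The detail worth noting in the new check: it tests the full path with `endswith` before the extension filter runs, because a compound suffix such as `.vae.safetensors` cannot be detected by `os.path.splitext`, which only reports the final `.safetensors` component. Below is a minimal, self-contained sketch of the same filtering logic; the file names and the `filter_models` helper are illustrative stand-ins, not code from the repository.

```python
import os

def filter_models(paths, ext_filter, ext_blacklist=None):
    results = []
    for path in paths:
        # compound suffixes like ".vae.safetensors" need endswith();
        # os.path.splitext() would only report ".safetensors"
        if ext_blacklist is not None and any(path.endswith(x) for x in ext_blacklist):
            continue
        _, extension = os.path.splitext(path)
        if extension in ext_filter:
            results.append(path)
    return results

print(filter_models(
    ["model.ckpt", "model.safetensors", "model.vae.safetensors"],
    ext_filter=[".ckpt", ".safetensors"],
    ext_blacklist=[".vae.safetensors"],
))
# ['model.ckpt', 'model.safetensors']
```

`model.vae.safetensors` is dropped even though its extension passes `ext_filter`, which is what lets `list_models` in sd_models.py (below) keep accepting `.safetensors` checkpoints while hiding standalone VAE files.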
modules/sd_disable_initialization.py +25 −4
import ldm.modules.encoders.modules
import open_clip
import torch
+import transformers.utils.hub


class DisableInitialization:
    """
-    When an object of this class enters a `with` block, it starts preventing torch's layer initialization
-    functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves,
-    reverts everything to how it was.
+    When an object of this class enters a `with` block, it starts:
+    - preventing torch's layer initialization functions from working
+    - changing CLIP and OpenCLIP so that they do not download model weights
+    - changing CLIP so that it does not make requests to check whether there is a new version of a file you already have

-    Use like this:
+    When it leaves the block, it reverts everything to how it was before.
+
+    Use it like this:
    ```
    with DisableInitialization():
        do_things()
@@ -26,19 +30,36 @@ class DisableInitialization:
        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)

+        def transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs):
+
+            # this file is always 404, prevent making request
+            if url == 'https://huggingface.co/openai/clip-vit-large-patch14/resolve/main/added_tokens.json':
+                raise transformers.utils.hub.EntryNotFoundError
+
+            try:
+                return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=True, **kwargs)
+            except Exception as e:
+                return self.transformers_utils_hub_get_from_cache(url, *args, local_files_only=False, **kwargs)
+
        self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
        self.init_no_grad_normal = torch.nn.init._no_grad_normal_
        self.init_no_grad_uniform_ = torch.nn.init._no_grad_uniform_
        self.create_model_and_transforms = open_clip.create_model_and_transforms
        self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained
+        self.transformers_utils_hub_get_from_cache = transformers.utils.hub.get_from_cache

        torch.nn.init.kaiming_uniform_ = do_nothing
        torch.nn.init._no_grad_normal_ = do_nothing
        torch.nn.init._no_grad_uniform_ = do_nothing
        open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
        ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained
+        transformers.utils.hub.get_from_cache = transformers_utils_hub_get_from_cache

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
        torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
        torch.nn.init._no_grad_uniform_ = self.init_no_grad_uniform_
        open_clip.create_model_and_transforms = self.create_model_and_transforms
        ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
+        transformers.utils.hub.get_from_cache = self.transformers_utils_hub_get_from_cache
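Everything this class does is plain attribute monkey-patching: the originals are saved on `self` when the block is entered, wrappers are swapped in, and `__exit__` puts the originals back, so the hooks cannot outlive the `with` block. The new hub wrapper adds one more trick: it asks for the cached copy first (`local_files_only=True`) and only falls back to a real network request when that fails. Here is a generic sketch of the same save/replace/restore pattern with the cache-first fallback; `fakelib`, `fetch`, and `CacheFirst` are hypothetical names, not webui or transformers API.

```python
import types

# a stand-in "library" with one function we want to hook
fakelib = types.SimpleNamespace()

def fetch(url, local_files_only=False):
    if local_files_only:
        raise FileNotFoundError(url)  # pretend the local cache is empty
    return f"downloaded {url}"

fakelib.fetch = fetch

class CacheFirst:
    def __enter__(self):
        # save the original so __exit__ can restore it
        self.original_fetch = fakelib.fetch

        def fetch_cache_first(url, *args, **kwargs):
            # try the local cache first, then fall back to the network,
            # mirroring the wrapper around transformers.utils.hub.get_from_cache
            try:
                return self.original_fetch(url, *args, local_files_only=True, **kwargs)
            except Exception:
                return self.original_fetch(url, *args, local_files_only=False, **kwargs)

        fakelib.fetch = fetch_cache_first
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        fakelib.fetch = self.original_fetch  # restore the original

with CacheFirst():
    print(fakelib.fetch("some/model.bin"))  # cache miss -> "downloaded some/model.bin"
# after the block, fakelib.fetch is the original function again
```

The broad `except Exception` mirrors the real wrapper: any cache miss, however it is reported, silently becomes an ordinary download.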
modules/sd_models.py +28 −4
@@ -2,6 +2,7 @@ import collections
import os.path
import sys
import gc
+import time
from collections import namedtuple
import torch
import re
@@ -61,7 +62,7 @@ def find_checkpoint_config(info):

def list_models():
    checkpoints_list.clear()
-    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
+    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"], ext_blacklist=[".vae.safetensors"])

    def modeltitle(path, shorthash):
        abspath = os.path.abspath(path)
@@ -288,6 +289,17 @@ def enable_midas_autodownload():
    midas.api.load_model = load_model_wrapper


+class Timer:
+    def __init__(self):
+        self.start = time.time()
+
+    def elapsed(self):
+        end = time.time()
+        res = end - self.start
+        self.start = end
+        return res
+
+
def load_model(checkpoint_info=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
@@ -319,11 +331,17 @@ def load_model(checkpoint_info=None):
    if shared.cmd_opts.no_half:
        sd_config.model.params.unet_config.params.use_fp16 = False

+    timer = Timer()
+
    with sd_disable_initialization.DisableInitialization():
        sd_model = instantiate_from_config(sd_config.model)

+    elapsed_create = timer.elapsed()
+
    load_model_weights(sd_model, checkpoint_info)

+    elapsed_load_weights = timer.elapsed()
+
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
    else:
@@ -338,7 +356,9 @@ def load_model(checkpoint_info=None):

    script_callbacks.model_loaded_callback(sd_model)

    print("Model loaded.")
    elapsed_the_rest = timer.elapsed()

    print(f"Model loaded in {elapsed_create + elapsed_load_weights + elapsed_the_rest:.1f}s ({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")

    return sd_model

@@ -371,6 +391,8 @@ def reload_model_weights(sd_model=None, info=None):

    sd_hijack.model_hijack.undo_hijack(sd_model)

+    timer = Timer()
+
    try:
        load_model_weights(sd_model, checkpoint_info)
    except Exception as e:
@@ -384,6 +406,8 @@ def reload_model_weights(sd_model=None, info=None):
        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
            sd_model.to(devices.device)

    print("Weights loaded.")
    elapsed = timer.elapsed()

    print(f"Weights loaded in {elapsed:.1f}s.")

    return sd_model
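The Timer class added above is a split timer rather than a stopwatch: each `elapsed()` call returns the time since the previous call and resets the reference point, which is what lets `load_model` attribute its total across distinct phases. A standalone sketch of the same usage pattern, with `time.sleep` standing in for the actual load phases:

```python
import time

class Timer:
    def __init__(self):
        self.start = time.time()

    def elapsed(self):
        # return the time since the previous call (or construction), then reset
        end = time.time()
        res = end - self.start
        self.start = end
        return res

timer = Timer()
time.sleep(0.2)   # stands in for instantiate_from_config
elapsed_create = timer.elapsed()
time.sleep(0.3)   # stands in for load_model_weights
elapsed_load_weights = timer.elapsed()

print(f"Model loaded in {elapsed_create + elapsed_load_weights:.1f}s "
      f"({elapsed_create:.1f}s create model, {elapsed_load_weights:.1f}s load weights).")
```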