Commit 0c3feb20 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

disable torch weight initialization and CLIP downloading/reading checkpoint to...

disable torch weight initialization and CLIP downloading/reading checkpoint to speedup creating sd model from config
parent ef75c980
Loading
Loading
Loading
Loading
+44 −0
Original line number Diff line number Diff line
import ldm.modules.encoders.modules
import open_clip
import torch


class DisableInitialization:
    """
    When an object of this class enters a `with` block, it starts preventing torch's layer initialization
    functions from working, and changes CLIP and OpenCLIP to not download model weights. When it leaves,
    reverts everything to how it was.

    Use like this:
    ```
    with DisableInitialization():
        do_things()
    ```
    """

    def __enter__(self):
        def do_nothing(*args, **kwargs):
            pass

        def create_model_and_transforms_without_pretrained(*args, pretrained=None, **kwargs):
            return self.create_model_and_transforms(*args, pretrained=None, **kwargs)

        def CLIPTextModel_from_pretrained(pretrained_model_name_or_path, *model_args, **kwargs):
            return self.CLIPTextModel_from_pretrained(None, *model_args, config=pretrained_model_name_or_path, state_dict={}, **kwargs)

        self.init_kaiming_uniform = torch.nn.init.kaiming_uniform_
        self.init_no_grad_normal = torch.nn.init._no_grad_normal_
        self.create_model_and_transforms = open_clip.create_model_and_transforms
        self.CLIPTextModel_from_pretrained = ldm.modules.encoders.modules.CLIPTextModel.from_pretrained

        torch.nn.init.kaiming_uniform_ = do_nothing
        torch.nn.init._no_grad_normal_ = do_nothing
        open_clip.create_model_and_transforms = create_model_and_transforms_without_pretrained
        ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = CLIPTextModel_from_pretrained

    def __exit__(self, exc_type, exc_val, exc_tb):
        torch.nn.init.kaiming_uniform_ = self.init_kaiming_uniform
        torch.nn.init._no_grad_normal_ = self.init_no_grad_normal
        open_clip.create_model_and_transforms = self.create_model_and_transforms
        ldm.modules.encoders.modules.CLIPTextModel.from_pretrained = self.CLIPTextModel_from_pretrained
+3 −2
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ import ldm.modules.midas as midas

from ldm.util import instantiate_from_config

from modules import shared, modelloader, devices, script_callbacks, sd_vae
from modules import shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization
from modules.paths import models_path
from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inpainting

@@ -319,6 +319,7 @@ def load_model(checkpoint_info=None):
    if shared.cmd_opts.no_half:
        sd_config.model.params.unet_config.params.use_fp16 = False

    with sd_disable_initialization.DisableInitialization():
        sd_model = instantiate_from_config(sd_config.model)

    load_model_weights(sd_model, checkpoint_info)