Commit 668d7e9b authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make it possible to load SD1 checkpoints without CLIP

parent 3e0f9a75
Loading
Loading
Loading
Loading
+10 −7
Original line number Diff line number Diff line
@@ -20,8 +20,9 @@ class DisableInitialization:
    ```
    """

    def __init__(self):
    def __init__(self, disable_clip=True):
        self.replaced = []
        self.disable_clip = disable_clip

    def replace(self, obj, field, func):
        original = getattr(obj, field, None)
@@ -75,6 +76,8 @@ class DisableInitialization:
        self.replace(torch.nn.init, 'kaiming_uniform_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_normal_', do_nothing)
        self.replace(torch.nn.init, '_no_grad_uniform_', do_nothing)

        if self.disable_clip:
            self.create_model_and_transforms = self.replace(open_clip, 'create_model_and_transforms', create_model_and_transforms_without_pretrained)
            self.CLIPTextModel_from_pretrained = self.replace(ldm.modules.encoders.modules.CLIPTextModel, 'from_pretrained', CLIPTextModel_from_pretrained)
            self.transformers_modeling_utils_load_pretrained_model = self.replace(transformers.modeling_utils.PreTrainedModel, '_load_pretrained_model', transformers_modeling_utils_load_pretrained_model)
+5 −1
Original line number Diff line number Diff line
@@ -354,6 +354,9 @@ def repair_config(sd_config):
        sd_config.model.params.unet_config.params.use_fp16 = True


sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'

def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
@@ -374,6 +377,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
        state_dict = get_checkpoint_state_dict(checkpoint_info, timer)

    checkpoint_config = sd_models_config.find_checkpoint_config(state_dict, checkpoint_info)
    clip_is_included_into_sd = sd1_clip_weight in state_dict or sd2_clip_weight in state_dict

    timer.record("find config")

@@ -386,7 +390,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_

    sd_model = None
    try:
        with sd_disable_initialization.DisableInitialization():
        with sd_disable_initialization.DisableInitialization(disable_clip=clip_is_included_into_sd):
            sd_model = instantiate_from_config(sd_config.model)
    except Exception as e:
        pass