Commit c715ef04 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

fix for incorrect model weight loading for #814

parent 965dcf44
Loading
Loading
Loading
Loading
+9 −0
Original line number Diff line number Diff line
@@ -245,6 +245,7 @@ class StableDiffusionModelHijack:

        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        self.clip = m.cond_stage_model

        if cmd_opts.opt_split_attention_v1:
@@ -263,6 +264,14 @@ class StableDiffusionModelHijack:

        self.layers = flatten(m)

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == FrozenCLIPEmbedderWithCustomWords:
            m.cond_stage_model = m.cond_stage_model.wrapped

        model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
        if type(model_embeddings.token_embedding) == EmbeddingsWithFixes:
            model_embeddings.token_embedding = model_embeddings.token_embedding.wrapped

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return
+5 −1
Original line number Diff line number Diff line
@@ -137,7 +137,7 @@ def load_model():


def reload_model_weights(sd_model, info=None):
    from modules import lowvram, devices
    from modules import lowvram, devices, sd_hijack
    checkpoint_info = info or select_checkpoint()

    if sd_model.sd_model_checkpint == checkpoint_info.filename:
@@ -148,8 +148,12 @@ def reload_model_weights(sd_model, info=None):
    else:
        sd_model.to(devices.cpu)

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)

    sd_hijack.model_hijack.hijack(sd_model)

    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)