Commit f34c7341 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

alt-diffusion integration

parent 3f401cdb
Loading
Loading
Loading
Loading
+0 −0

File moved.

+10 −8
Original line number Diff line number Diff line
@@ -5,7 +5,7 @@ import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

from modules.sd_hijack_optimizations import invokeAI_mps_available

@@ -68,6 +68,7 @@ def fix_checkpoint():
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward


class StableDiffusionModelHijack:
    fixes = None
    comments = []
@@ -79,19 +80,20 @@ class StableDiffusionModelHijack:

    def hijack(self, m):

        if shared.text_model_name == "XLMR-Large":
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
            m.cond_stage_model = sd_hijack_xlmr.FrozenXLMREmbedderWithCustomWords(m.cond_stage_model, self)

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenCLIPEmbedder:
            model_embeddings = m.cond_stage_model.transformer.text_model.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
            m.cond_stage_model = sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
            apply_optimizations()

        elif type(m.cond_stage_model) == ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder:
            m.cond_stage_model.model.token_embedding = EmbeddingsWithFixes(m.cond_stage_model.model.token_embedding, self)
            m.cond_stage_model = sd_hijack_open_clip.FrozenOpenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)

        apply_optimizations()

        self.clip = m.cond_stage_model
@@ -109,7 +111,7 @@ class StableDiffusionModelHijack:

    def undo_hijack(self, m):

        if shared.text_model_name == "XLMR-Large":
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped 

        elif type(m.cond_stage_model) == sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords:
+5 −9
Original line number Diff line number Diff line
@@ -4,7 +4,6 @@ import torch

from modules import prompt_parser, devices
from modules.shared import opts
import modules.shared as shared

def get_target_prompt_token_count(token_count):
    return math.ceil(max(token_count, 1) / 75) * 75
@@ -177,9 +176,6 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

    def forward(self, text):
        if shared.text_model_name == "XLMR-Large":
            return self.wrapped.encode(text)

        use_old = opts.use_old_emphasis_implementation
        if use_old:
            batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text_old(text)
@@ -257,13 +253,13 @@ class FrozenCLIPEmbedderWithCustomWords(FrozenCLIPEmbedderWithCustomWordsBase):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)
        self.tokenizer = wrapped.tokenizer
        if shared.text_model_name == "XLMR-Large":
            self.comma_token = None 
        else :
            self.comma_token = [v for k, v in self.tokenizer.get_vocab().items() if k == ',</w>'][0]

        vocab = self.tokenizer.get_vocab()

        self.comma_token = vocab.get(',</w>', None)

        self.token_mults = {}
        tokens_with_parens = [(k, v) for k, v in self.tokenizer.get_vocab().items() if '(' in k or ')' in k or '[' in k or ']' in k]
        tokens_with_parens = [(k, v) for k, v in vocab.items() if '(' in k or ')' in k or '[' in k or ']' in k]
        for text, ident in tokens_with_parens:
            mult = 1.0
            for c in text:
+34 −0
Original line number Diff line number Diff line
import open_clip.tokenizer
import torch

from modules import sd_hijack_clip, devices
from modules.shared import opts


class FrozenXLMREmbedderWithCustomWords(sd_hijack_clip.FrozenCLIPEmbedderWithCustomWords):
    def __init__(self, wrapped, hijack):
        super().__init__(wrapped, hijack)

        self.id_start = wrapped.config.bos_token_id
        self.id_end = wrapped.config.eos_token_id
        self.id_pad = wrapped.config.pad_token_id

        self.comma_token = self.tokenizer.get_vocab().get(',', None)  # alt diffusion doesn't have </w> bits for comma

    def encode_with_transformers(self, tokens):
        # there's no CLIP Skip here because all hidden layers have size of 1024 and the last one uses a
        # trained layer to transform those 1024 into 768 for unet; so you can't choose which transformer
        # layer to work with - you have to use the last

        attention_mask = (tokens != self.id_pad).to(device=tokens.device, dtype=torch.int64)
        features = self.wrapped(input_ids=tokens, attention_mask=attention_mask)
        z = features['projection_state']

        return z

    def encode_embedding_init_text(self, init_text, nvpt):
        embedding_layer = self.wrapped.roberta.embeddings
        ids = self.wrapped.tokenizer(init_text, max_length=nvpt, return_tensors="pt", add_special_tokens=False)["input_ids"]
        embedded = embedding_layer.token_embedding.wrapped(ids.to(devices.device)).squeeze(0)

        return embedded
Loading