Commit 6f0abbb7 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

textual inversion support for SDXL

parent 4ca9f70b
Loading
Loading
Loading
Loading
+5 −3
Original line number Diff line number Diff line
@@ -197,7 +197,7 @@ class StableDiffusionModelHijack:
                    conditioner.embedders[i] = sd_hijack_clip.FrozenCLIPEmbedderForSDXLWithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])
                if typename == 'FrozenOpenCLIPEmbedder2':
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self)
                    embedder.model.token_embedding = EmbeddingsWithFixes(embedder.model.token_embedding, self, textual_inversion_key='clip_g')
                    conditioner.embedders[i] = sd_hijack_open_clip.FrozenOpenCLIPEmbedder2WithCustomWords(embedder, self)
                    text_cond_models.append(conditioner.embedders[i])

@@ -292,10 +292,11 @@ class StableDiffusionModelHijack:


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
    def __init__(self, wrapped, embeddings, textual_inversion_key='clip_l'):
        super().__init__()
        self.wrapped = wrapped
        self.embeddings = embeddings
        self.textual_inversion_key = textual_inversion_key

    def forward(self, input_ids):
        batch_fixes = self.embeddings.fixes
@@ -309,7 +310,8 @@ class EmbeddingsWithFixes(torch.nn.Module):
        vecs = []
        for fixes, tensor in zip(batch_fixes, inputs_embeds):
            for offset, embedding in fixes:
                emb = devices.cond_cast_unet(embedding.vec)
                vec = embedding.vec[self.textual_inversion_key] if isinstance(embedding.vec, dict) else embedding.vec
                emb = devices.cond_cast_unet(vec)
                emb_len = min(tensor.shape[0] - offset - 1, emb.shape[0])
                tensor = torch.cat([tensor[0:offset + 1], emb[0:emb_len], tensor[offset + 1 + emb_len:]])

+1 −1
Original line number Diff line number Diff line
@@ -161,7 +161,7 @@ class FrozenCLIPEmbedderWithCustomWordsBase(torch.nn.Module):
                    position += 1
                    continue

                emb_len = int(embedding.vec.shape[0])
                emb_len = int(embedding.vectors)
                if len(chunk.tokens) + emb_len > self.chunk_length:
                    next_chunk()

+9 −0
Original line number Diff line number Diff line
@@ -56,6 +56,14 @@ def encode_embedding_init_text(self: sgm.modules.GeneralConditioner, init_text,
    return torch.cat(res, dim=1)


def tokenize(self: sgm.modules.GeneralConditioner, texts):
    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'tokenize')]:
        return embedder.tokenize(texts)

    raise AssertionError('no tokenizer available')



def process_texts(self, texts):
    for embedder in [embedder for embedder in self.embedders if hasattr(embedder, 'process_texts')]:
        return embedder.process_texts(texts)
@@ -68,6 +76,7 @@ def get_target_prompt_token_count(self, token_count):

# those additions to GeneralConditioner make it possible to use it as model.cond_stage_model from SD1.5 in exist
sgm.modules.GeneralConditioner.encode_embedding_init_text = encode_embedding_init_text
sgm.modules.GeneralConditioner.tokenize = tokenize
sgm.modules.GeneralConditioner.process_texts = process_texts
sgm.modules.GeneralConditioner.get_target_prompt_token_count = get_target_prompt_token_count

+14 −5
Original line number Diff line number Diff line
@@ -181,29 +181,38 @@ class EmbeddingDatabase:
        else:
            return


        # textual inversion embeddings
        if 'string_to_param' in data:
            param_dict = data['string_to_param']
            param_dict = getattr(param_dict, '_parameters', param_dict)  # fix for torch 1.12.1 loading saved file from torch 1.11
            assert len(param_dict) == 1, 'embedding file has multiple terms in it'
            emb = next(iter(param_dict.items()))[1]
        # diffuser concepts
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor:
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            shape = vec.shape[-1]
            vectors = vec.shape[0]
        elif type(data) == dict and 'clip_g' in data and 'clip_l' in data:  # SDXL embedding
            vec = {k: v.detach().to(devices.device, dtype=torch.float32) for k, v in data.items()}
            shape = data['clip_g'].shape[-1] + data['clip_l'].shape[-1]
            vectors = data['clip_g'].shape[0]
        elif type(data) == dict and type(next(iter(data.values()))) == torch.Tensor: # diffuser concepts
            assert len(data.keys()) == 1, 'embedding file has multiple terms in it'

            emb = next(iter(data.values()))
            if len(emb.shape) == 1:
                emb = emb.unsqueeze(0)
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            shape = vec.shape[-1]
            vectors = vec.shape[0]
        else:
            raise Exception(f"Couldn't identify {filename} as neither textual inversion embedding nor diffuser concept.")

        vec = emb.detach().to(devices.device, dtype=torch.float32)
        embedding = Embedding(vec, name)
        embedding.step = data.get('step', None)
        embedding.sd_checkpoint = data.get('sd_checkpoint', None)
        embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
        embedding.vectors = vec.shape[0]
        embedding.shape = vec.shape[-1]
        embedding.vectors = vectors
        embedding.shape = shape
        embedding.filename = path
        embedding.set_hash(hashes.sha256(embedding.filename, "textual_inversion/" + name) or '')