Commit bdbe0982 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

changed embedding accepted shape detection to use existing code and support...

changed embedding accepted shape detection to use existing code and support the new alt-diffusion model, and reformatted messages a bit #6149
parent c24a314c
Loading
Loading
Loading
Loading
+6 −24
Original line number Diff line number Diff line
@@ -80,23 +80,8 @@ class EmbeddingDatabase:
        return embedding

    def get_expected_shape(self):
        expected_shape = -1 # initialize with unknown
        idx = torch.tensor(0).to(shared.device)
        if expected_shape == -1:
            try: # matches sd15 signature
                first_embedding = shared.sd_model.cond_stage_model.wrapped.transformer.text_model.embeddings.token_embedding.wrapped(idx)
                expected_shape = first_embedding.shape[0]
            except:
                pass
        if expected_shape == -1:
            try: # matches sd20 signature
                first_embedding = shared.sd_model.cond_stage_model.wrapped.model.token_embedding.wrapped(idx)
                expected_shape = first_embedding.shape[0]
            except:
                pass
        if expected_shape == -1:
            print('Could not determine expected embeddings shape from model')
        return expected_shape
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_textual_inversion_embeddings(self, force_reload = False):
        mt = os.path.getmtime(self.embeddings_dir)
@@ -112,8 +97,6 @@ class EmbeddingDatabase:
        def process_file(path, filename):
            name = os.path.splitext(filename)[0]

            data = []

            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                embed_image = Image.open(path)
                if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
@@ -150,11 +133,10 @@ class EmbeddingDatabase:
            embedding.vectors = vec.shape[0]
            embedding.shape = vec.shape[-1]

            if (self.expected_shape == -1) or (self.expected_shape == embedding.shape):
            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
                self.register_embedding(embedding, shared.sd_model)
            else:
                self.skipped_embeddings.append(name)
                # print('Skipping embedding {name}: shape was {shape} expected {expected}'.format(name = name, shape = embedding.shape, expected = self.expected_shape))

        for fn in os.listdir(self.embeddings_dir):
            try:
@@ -169,9 +151,9 @@ class EmbeddingDatabase:
                print(traceback.format_exc(), file=sys.stderr)
                continue

        print("Textual inversion embeddings {num} loaded: {val}".format(num = len(self.word_embeddings), val = ', '.join(self.word_embeddings.keys())))
        if (len(self.skipped_embeddings) > 0):
            print("Textual inversion embeddings {num} skipped: {val}".format(num = len(self.skipped_embeddings), val = ', '.join(self.skipped_embeddings)))
        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
        if len(self.skipped_embeddings) > 0:
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")

    def find_embedding_at_position(self, tokens, offset):
        token = tokens[offset]