Commit 085427de authored by AUTOMATIC's avatar AUTOMATIC
Browse files

make it possible for extensions/scripts to add their own embedding directories

parent a0c87f1f
Loading
Loading
Loading
Loading
+4 −3
Original line number Diff line number Diff line
@@ -83,10 +83,12 @@ class StableDiffusionModelHijack:
    clip = None
    optimization_method = None

    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase(cmd_opts.embeddings_dir)
    embedding_db = modules.textual_inversion.textual_inversion.EmbeddingDatabase()

    def hijack(self, m):
    def __init__(self):
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            model_embeddings = m.cond_stage_model.roberta.embeddings
            model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.word_embeddings, self)
@@ -117,7 +119,6 @@ class StableDiffusionModelHijack:
        self.layers = flatten(m)

    def undo_hijack(self, m):

        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped 

+104 −66
Original line number Diff line number Diff line
@@ -66,17 +66,41 @@ class Embedding:
        return self.cached_checksum


class DirWithTextualInversionEmbeddings:
    def __init__(self, path):
        self.path = path
        self.mtime = None

    def has_changed(self):
        if not os.path.isdir(self.path):
            return False

        mt = os.path.getmtime(self.path)
        if self.mtime is None or mt > self.mtime:
            return True

    def update(self):
        if not os.path.isdir(self.path):
            return

        self.mtime = os.path.getmtime(self.path)


class EmbeddingDatabase:
    def __init__(self, embeddings_dir):
    def __init__(self):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = {}
        self.dir_mtime = None
        self.embeddings_dir = embeddings_dir
        self.expected_shape = -1
        self.embedding_dirs = {}

    def register_embedding(self, embedding, model):
    def add_embedding_dir(self, path):
        self.embedding_dirs[path] = DirWithTextualInversionEmbeddings(path)

    def clear_embedding_dirs(self):
        self.embedding_dirs.clear()

    def register_embedding(self, embedding, model):
        self.word_embeddings[embedding.name] = embedding

        ids = model.cond_stage_model.tokenize([embedding.name])[0]
@@ -93,18 +117,7 @@ class EmbeddingDatabase:
        vec = shared.sd_model.cond_stage_model.encode_embedding_init_text(",", 1)
        return vec.shape[1]

    def load_textual_inversion_embeddings(self, force_reload = False):
        mt = os.path.getmtime(self.embeddings_dir)
        if not force_reload and self.dir_mtime is not None and mt <= self.dir_mtime:
            return

        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        def process_file(path, filename):
    def load_from_file(self, path, filename):
        name, ext = os.path.splitext(filename)
        ext = ext.upper()

@@ -155,7 +168,11 @@ class EmbeddingDatabase:
        else:
            self.skipped_embeddings[name] = embedding

        for root, dirs, fns in os.walk(self.embeddings_dir):
    def load_from_dir(self, embdir):
        if not os.path.isdir(embdir.path):
            return

        for root, dirs, fns in os.walk(embdir.path):
            for fn in fns:
                try:
                    fullfn = os.path.join(root, fn)
@@ -163,12 +180,32 @@ class EmbeddingDatabase:
                    if os.stat(fullfn).st_size == 0:
                        continue

                    process_file(fullfn, fn)
                    self.load_from_file(fullfn, fn)
                except Exception:
                    print(f"Error loading embedding {fn}:", file=sys.stderr)
                    print(traceback.format_exc(), file=sys.stderr)
                    continue

    def load_textual_inversion_embeddings(self, force_reload=False):
        if not force_reload:
            need_reload = False
            for path, embdir in self.embedding_dirs.items():
                if embdir.has_changed():
                    need_reload = True
                    break

            if not need_reload:
                return

        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        for path, embdir in self.embedding_dirs.items():
            self.load_from_dir(embdir)
            embdir.update()

        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
        if len(self.skipped_embeddings) > 0:
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")
@@ -259,6 +296,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"


def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0