Unverified Commit 7bbd984d authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #6253 from Shondoit/ti-optim

Save Optimizer next to TI embedding
parents 545ae8cb bddebe09
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -356,7 +356,7 @@ options_templates.update(options_section(('system', "System"), {
options_templates.update(options_section(('training', "Training"), {
    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
    "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training of embedding or HN can be resumed with the matching optim file."),
    "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
    "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
    "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
+32 −8
Original line number Diff line number Diff line
@@ -28,6 +28,7 @@ class Embedding:
        self.cached_checksum = None
        self.sd_checkpoint = None
        self.sd_checkpoint_name = None
        self.optimizer_state_dict = None

    def save(self, filename):
        embedding_data = {
@@ -41,6 +42,13 @@ class Embedding:

        torch.save(embedding_data, filename)

        if shared.opts.save_optimizer_state and self.optimizer_state_dict is not None:
            optimizer_saved_dict = {
                'hash': self.checksum(),
                'optimizer_state_dict': self.optimizer_state_dict,
            }
            torch.save(optimizer_saved_dict, filename + '.optim')

    def checksum(self):
        if self.cached_checksum is not None:
            return self.cached_checksum
@@ -95,9 +103,10 @@ class EmbeddingDatabase:
        self.expected_shape = self.get_expected_shape()

        def process_file(path, filename):
            name = os.path.splitext(filename)[0]
            name, ext = os.path.splitext(filename)
            ext = ext.upper()

            if os.path.splitext(filename.upper())[-1] in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
            if ext in ['.PNG', '.WEBP', '.JXL', '.AVIF']:
                embed_image = Image.open(path)
                if hasattr(embed_image, 'text') and 'sd-ti-embedding' in embed_image.text:
                    data = embedding_from_b64(embed_image.text['sd-ti-embedding'])
@@ -105,8 +114,10 @@ class EmbeddingDatabase:
                else:
                    data = extract_image_data_embed(embed_image)
                    name = data.get('name', name)
            else:
            elif ext in ['.BIN', '.PT']:
                data = torch.load(path, map_location="cpu")
            else:
                return

            # textual inversion embeddings
            if 'string_to_param' in data:
@@ -301,6 +312,20 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    if shared.opts.save_optimizer_state:
        optimizer_state_dict = None
        if os.path.exists(filename + '.optim'):
            optimizer_saved_dict = torch.load(filename + '.optim', map_location='cpu')
            if embedding.checksum() == optimizer_saved_dict.get('hash', None):
                optimizer_state_dict = optimizer_saved_dict.get('optimizer_state_dict', None)
    
        if optimizer_state_dict is not None:
            optimizer.load_state_dict(optimizer_state_dict)
            print("Loaded existing optimizer from checkpoint")
        else:
            print("No saved optimizer exists in checkpoint")


    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size
@@ -367,9 +392,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                    # Before saving, change name to match current checkpoint.
                    embedding_name_every = f'{embedding_name}-{steps_done}'
                    last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
                    #if shared.opts.save_optimizer_state:
                        #embedding.optimizer_state_dict = optimizer.state_dict()
                    save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                    save_embedding(embedding, optimizer, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
                    embedding_yet_to_be_embedded = True

                write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, steps_per_epoch, {
@@ -459,7 +482,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
        save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)
        save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True)
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
        pass
@@ -471,7 +494,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>

    return embedding, filename

def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
def save_embedding(embedding, optimizer, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    old_embedding_name = embedding.name
    old_sd_checkpoint = embedding.sd_checkpoint if hasattr(embedding, "sd_checkpoint") else None
    old_sd_checkpoint_name = embedding.sd_checkpoint_name if hasattr(embedding, "sd_checkpoint_name") else None
@@ -482,6 +505,7 @@ def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cache
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.optimizer_state_dict = optimizer.state_dict()
        embedding.save(filename)
    except:
        embedding.sd_checkpoint = old_sd_checkpoint