Commit a07f054c authored by Muhammad Rizqi Nur
Browse files

Add missing info on hypernetwork/embedding model log

parent ab05a74e
Loading
Loading
Loading
Loading
+21 −10
Original line number Original line Diff line number Diff line
@@ -361,6 +361,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
        images_dir = None
        images_dir = None


    hypernetwork = shared.loaded_hypernetwork
    hypernetwork = shared.loaded_hypernetwork
    checkpoint = sd_models.select_checkpoint()


    ititial_step = hypernetwork.step or 0
    ititial_step = hypernetwork.step or 0
    if ititial_step > steps:
    if ititial_step > steps:
@@ -449,9 +450,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log


        if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
        if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
            # Before saving, change name to match current checkpoint.
            # Before saving, change name to match current checkpoint.
            hypernetwork.name = f'{hypernetwork_name}-{steps_done}'
            hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name_every}.pt')
            hypernetwork.save(last_saved_file)
            save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)


        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{previous_mean_loss:.7f}",
            "loss": f"{previous_mean_loss:.7f}",
@@ -512,13 +513,23 @@ Last saved image: {html.escape(last_saved_image)}<br/>
"""
"""
        
        
    report_statistics(loss_dict)
    report_statistics(loss_dict)
    checkpoint = sd_models.select_checkpoint()


    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)

    return hypernetwork, filename

def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
    """Save *hypernetwork* to *filename*, stamping checkpoint metadata first.

    Before writing, the hypernetwork is renamed to *hypernetwork_name* and
    annotated with the hash and model name of the checkpoint it was trained
    against, so the saved .pt file records which SD model produced it.

    Args:
        hypernetwork: object exposing ``name``, ``sd_checkpoint``,
            ``sd_checkpoint_name`` attributes and a ``save(filename)`` method.
        checkpoint: object exposing ``hash`` and ``model_name`` attributes
            (as returned by ``sd_models.select_checkpoint()``).
        hypernetwork_name: name to store on the object before saving (may be
            the step-suffixed "<name>-<steps>" form for periodic saves).
        filename: destination path for the .pt file.

    Raises:
        Whatever ``hypernetwork.save`` raises. On any failure the previous
        name/checkpoint attributes are restored before re-raising, so a
        failed save never leaves the in-memory object mutated.
    """
    old_hypernetwork_name = hypernetwork.name
    # Attributes may be absent on older objects; default to None.
    old_sd_checkpoint = getattr(hypernetwork, "sd_checkpoint", None)
    old_sd_checkpoint_name = getattr(hypernetwork, "sd_checkpoint_name", None)
    try:
        hypernetwork.sd_checkpoint = checkpoint.hash
        hypernetwork.sd_checkpoint_name = checkpoint.model_name
        hypernetwork.name = hypernetwork_name
        hypernetwork.save(filename)
    except BaseException:
        # BaseException (not Exception) so the rollback also runs on
        # KeyboardInterrupt, matching the original bare `except:` + `raise`.
        hypernetwork.sd_checkpoint = old_sd_checkpoint
        hypernetwork.sd_checkpoint_name = old_sd_checkpoint_name
        hypernetwork.name = old_hypernetwork_name
        raise
+26 −13
Original line number Original line Diff line number Diff line
@@ -119,7 +119,7 @@ class EmbeddingDatabase:
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            vec = emb.detach().to(devices.device, dtype=torch.float32)
            embedding = Embedding(vec, name)
            embedding = Embedding(vec, name)
            embedding.step = data.get('step', None)
            embedding.step = data.get('step', None)
            embedding.sd_checkpoint = data.get('hash', None)
            embedding.sd_checkpoint = data.get('sd_checkpoint', None)
            embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
            embedding.sd_checkpoint_name = data.get('sd_checkpoint_name', None)
            self.register_embedding(embedding, shared.sd_model)
            self.register_embedding(embedding, shared.sd_model)


@@ -259,6 +259,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
    hijack = sd_hijack.model_hijack
    hijack = sd_hijack.model_hijack


    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    embedding = hijack.embedding_db.word_embeddings[embedding_name]
    checkpoint = sd_models.select_checkpoint()


    ititial_step = embedding.step or 0
    ititial_step = embedding.step or 0
    if ititial_step > steps:
    if ititial_step > steps:
@@ -314,9 +315,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc


        if embedding_dir is not None and steps_done % save_embedding_every == 0:
        if embedding_dir is not None and steps_done % save_embedding_every == 0:
            # Before saving, change name to match current checkpoint.
            # Before saving, change name to match current checkpoint.
            embedding.name = f'{embedding_name}-{steps_done}'
            embedding_name_every = f'{embedding_name}-{steps_done}'
            last_saved_file = os.path.join(embedding_dir, f'{embedding.name}.pt')
            last_saved_file = os.path.join(embedding_dir, f'{embedding_name_every}.pt')
            embedding.save(last_saved_file)
            save_embedding(embedding, checkpoint, embedding_name_every, last_saved_file, remove_cached_checksum=True)
            embedding_yet_to_be_embedded = True
            embedding_yet_to_be_embedded = True


        write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
        write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
@@ -397,14 +398,26 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
</p>
"""
"""


    checkpoint = sd_models.select_checkpoint()
    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
    save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True)

    return embedding, filename


def save_embedding(embedding, checkpoint, embedding_name, filename, remove_cached_checksum=True):
    """Save *embedding* to *filename*, stamping checkpoint metadata first.

    Before writing, the embedding is renamed to *embedding_name* and
    annotated with the hash and model name of the checkpoint it was trained
    against, so the saved .pt file records which SD model produced it.

    Args:
        embedding: object exposing ``name``, ``sd_checkpoint``,
            ``sd_checkpoint_name``, ``cached_checksum`` attributes and a
            ``save(filename)`` method.
        checkpoint: object exposing ``hash`` and ``model_name`` attributes
            (as returned by ``sd_models.select_checkpoint()``).
        embedding_name: name to store on the object before saving (may be
            the step-suffixed "<name>-<steps>" form for periodic saves).
        filename: destination path for the .pt file.
        remove_cached_checksum: when True, clear ``cached_checksum`` before
            saving so a stale checksum is not persisted alongside the
            updated weights.

    Raises:
        Whatever ``embedding.save`` raises. On any failure the previous
        name/checkpoint/checksum attributes are restored before re-raising,
        so a failed save never leaves the in-memory object mutated.
    """
    old_embedding_name = embedding.name
    # Attributes may be absent on older objects; default to None.
    old_sd_checkpoint = getattr(embedding, "sd_checkpoint", None)
    old_sd_checkpoint_name = getattr(embedding, "sd_checkpoint_name", None)
    old_cached_checksum = getattr(embedding, "cached_checksum", None)
    try:
        embedding.sd_checkpoint = checkpoint.hash
        embedding.sd_checkpoint_name = checkpoint.model_name
        if remove_cached_checksum:
            embedding.cached_checksum = None
        embedding.name = embedding_name
        embedding.save(filename)
    except BaseException:
        # BaseException (not Exception) so the rollback also runs on
        # KeyboardInterrupt, matching the original bare `except:` + `raise`.
        embedding.sd_checkpoint = old_sd_checkpoint
        embedding.sd_checkpoint_name = old_sd_checkpoint_name
        embedding.name = old_embedding_name
        embedding.cached_checksum = old_cached_checksum
        raise