Commit 19818f02 authored by timntorres's avatar timntorres Committed by AUTOMATIC1111
Browse files

Match hypernet name with filename in all cases.

parent 51e3dc9c
Loading
Loading
Loading
Loading
+7 −1
Original line number Diff line number Diff line
@@ -340,7 +340,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
        pbar.set_description(f"loss: {mean_loss:.7f}")

        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
            temp = hypernetwork.name
            # Before saving, change name to match current checkpoint.
            hypernetwork.name = f'{hypernetwork_name}-{hypernetwork.step}'
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork.name}.pt')
            hypernetwork.save(last_saved_file)

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
@@ -405,6 +408,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>

    hypernetwork.sd_checkpoint = checkpoint.hash
    hypernetwork.sd_checkpoint_name = checkpoint.model_name
    # Before saving for the last time, change name back to the base name (as opposed to the save_hypernetwork_every step-suffixed naming convention).
    hypernetwork.name = hypernetwork_name
    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork.name}.pt')
    hypernetwork.save(filename)

    return hypernetwork, filename