Commit 348f89c8 authored by AngelBottomless's avatar AngelBottomless Committed by AUTOMATIC1111
Browse files

statistics for pbar

parent 40b56c92
Loading
Loading
Loading
Loading
+10 −2
Original line number Diff line number Diff line
@@ -335,6 +335,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    size = len(ds.indexes)
    loss_dict = defaultdict(lambda : deque(maxlen = 1024))
    losses = torch.zeros((size,))
    previous_mean_losses = [0]
    previous_mean_loss = 0
    print("Mean loss of {} elements".format(size))

@@ -356,7 +357,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    for i, entries in pbar:
        hypernetwork.step = i + ititial_step
        if len(loss_dict) > 0:
            previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
            previous_mean_losses = [i[-1] for i in loss_dict.values()]
            previous_mean_loss = mean(previous_mean_losses)
            
        scheduler.apply(optimizer, hypernetwork.step)
        if scheduler.finished:
@@ -391,7 +393,13 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
            raise RuntimeError("Loss diverged.")
        pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")
        
        if len(previous_mean_losses) > 1:
            std = stdev(previous_mean_losses)
        else:
            std = 0
        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
        pbar.set_description(dataset_loss_info)

        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            # Before saving, change name to match current checkpoint.