Commit 40b56c92 authored by AngelBottomless's avatar AngelBottomless Committed by AUTOMATIC1111
Browse files

cleanup some code

parent b297cc33
Loading
Loading
Loading
Loading
+3 −11
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum

from collections import defaultdict, deque
from statistics import stdev, mean

class HypernetworkModule(torch.nn.Module):
@@ -269,15 +270,6 @@ def stack_conds(conds):
    return torch.stack(conds)


def log_statistics(loss_info:dict, key, value):
    if key not in loss_info:
        loss_info[key] = [value]
    else:
        loss_info[key].append(value)
        if len(loss_info[key]) > 1024:
            loss_info[key].pop(0)


def statistics(data):
    total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
    recent_data = data[-32:]
@@ -341,7 +333,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
        weight.requires_grad = True

    size = len(ds.indexes)
    loss_dict = {}
    loss_dict = defaultdict(lambda : deque(maxlen = 1024))
    losses = torch.zeros((size,))
    previous_mean_loss = 0
    print("Mean loss of {} elements".format(size))
@@ -383,7 +375,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

            losses[hypernetwork.step % losses.shape[0]] = loss.item()
            for entry in entries:
                log_statistics(loss_dict, entry.filename, loss.item())
                loss_dict[entry.filename].append(loss.item())
                
            optimizer.zero_grad()
            weights[0].grad = None