Commit 24694e59 authored by AngelBottomless's avatar AngelBottomless Committed by AUTOMATIC1111
Browse files

Update hypernetwork.py

parent 321bacc6
Loading
Loading
Loading
Loading
+44 −11
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler
from torch import einsum

from statistics import stdev, mean

class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0
@@ -268,6 +269,32 @@ def stack_conds(conds):
    return torch.stack(conds)


def log_statistics(loss_info:dict, key, value):
    if key not in loss_info:
        loss_info[key] = [value]
    else:
        loss_info[key].append(value)
        if len(loss_info) > 1024:
            loss_info.pop(0)


def statistics(data):
    total_information = f"loss:{mean(data):.3f}"+u"\u00B1"+f"({stdev(data)/ (len(data)**0.5):.3f})"
    recent_data = data[-32:]
    recent_information = f"recent 32 loss:{mean(recent_data):.3f}"+u"\u00B1"+f"({stdev(recent_data)/ (len(recent_data)**0.5):.3f})"
    return total_information, recent_information


def report_statistics(loss_info:dict):
    keys = sorted(loss_info.keys(), key=lambda x: sum(loss_info[x]) / len(loss_info[x]))
    for key in keys:
        info, recent = statistics(loss_info[key])
        print("Loss statistics for file " + key)
        print(info)
        print(recent)



def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images
@@ -310,7 +337,11 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    for weight in weights:
        weight.requires_grad = True

    losses = torch.zeros((32,))
    size = len(ds.indexes)
    loss_dict = {}
    losses = torch.zeros((size,))
    previous_mean_loss = 0
    print("Mean loss of {} elements".format(size))

    last_saved_file = "<none>"
    last_saved_image = "<none>"
@@ -329,6 +360,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
    for i, entries in pbar:
        hypernetwork.step = i + ititial_step
        if loss_dict and i % size == 0:
            previous_mean_loss = sum(i[-1] for i in loss_dict.values()) / len(loss_dict)
            
        scheduler.apply(optimizer, hypernetwork.step)
        if scheduler.finished:
@@ -346,6 +379,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
            del c

            losses[hypernetwork.step % losses.shape[0]] = loss.item()
            for entry in entries:
                log_statistics(loss_dict, entry.filename, loss.item())
                
            optimizer.zero_grad()
            weights[0].grad = None
@@ -359,10 +394,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

            optimizer.step()

        mean_loss = losses.mean()
        if torch.isnan(mean_loss):
        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]):
            raise RuntimeError("Loss diverged.")
        pbar.set_description(f"loss: {mean_loss:.7f}")
        pbar.set_description(f"dataset loss: {previous_mean_loss:.7f}")

        if hypernetwork.step > 0 and hypernetwork_dir is not None and hypernetwork.step % save_hypernetwork_every == 0:
            # Before saving, change name to match current checkpoint.
@@ -371,7 +405,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
            hypernetwork.save(last_saved_file)

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{mean_loss:.7f}",
            "loss": f"{previous_mean_loss:.7f}",
            "learn_rate": scheduler.learn_rate
        })

@@ -420,7 +454,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

        shared.state.textinfo = f"""
<p>
Loss: {mean_loss:.7f}<br/>
Loss: {previous_mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
@@ -428,6 +462,7 @@ Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        
    report_statistics(loss_dict)
    checkpoint = sd_models.select_checkpoint()

    hypernetwork.sd_checkpoint = checkpoint.hash
@@ -438,5 +473,3 @@ Last saved image: {html.escape(last_saved_image)}<br/>
    hypernetwork.save(filename)

    return hypernetwork, filename