Commit 8f591298 authored by Melan's avatar Melan
Browse files

Some changes to the tensorboard code and hypernetwork support

parent a6d593a6
Loading
Loading
Loading
Loading
+17 −1
Original line number Diff line number Diff line
@@ -4,6 +4,7 @@ import html
import os
import sys
import traceback
import tensorboard
import tqdm
import csv

@@ -18,7 +19,6 @@ import modules.textual_inversion.dataset
from modules.textual_inversion import textual_inversion
from modules.textual_inversion.learn_schedule import LearnRateScheduler


class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0

@@ -291,6 +291,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
    optimizer = torch.optim.AdamW(weights, lr=scheduler.learn_rate)

    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)

    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
    for i, entries in pbar:
        hypernetwork.step = i + ititial_step
@@ -315,6 +318,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        mean_loss = losses.mean()
        if torch.isnan(mean_loss):
            raise RuntimeError("Loss diverged.")
@@ -324,6 +328,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
            last_saved_file = os.path.join(hypernetwork_dir, f'{hypernetwork_name}-{hypernetwork.step}.pt')
            hypernetwork.save(last_saved_file)
        
        if shared.opts.training_enable_tensorboard:
            epoch_num = hypernetwork.step // len(ds)
            epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1
            
            textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss,
                global_step=hypernetwork.step, step=epoch_step, 
                learn_rate=scheduler.learn_rate, epoch_num=epoch_num)

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{mean_loss:.7f}",
            "learn_rate": scheduler.learn_rate
@@ -360,6 +372,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
            processed = processing.process_images(p)
            image = processed.images[0] if len(processed.images)>0 else None

            if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}",
                      image, hypernetwork.step)

            if unload:
                shared.sd_model.cond_stage_model.to(devices.cpu)
                shared.sd_model.first_stage_model.to(devices.cpu)
+27 −18
Original line number Diff line number Diff line
@@ -201,16 +201,27 @@ def write_loss(log_directory, filename, step, epoch_len, values):
            **values,
        })

def tensorboard_setup(log_directory):
    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
    return SummaryWriter(
            log_dir=os.path.join(log_directory, "tensorboard"),
            flush_secs=shared.opts.training_tensorboard_flush_every)

def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
    tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
    tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
    tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
    tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)

def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
    if shared.opts.training_enable_tensorboard:
    tensorboard_writer.add_scalar(tag=tag, 
        scalar_value=value, global_step=step)

def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
    if shared.opts.training_enable_tensorboard:
    # Convert a pil image to a torch tensor
    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
        img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], len(pil_image.getbands()))
    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], 
        len(pil_image.getbands()))
    img_tensor = img_tensor.permute((2, 0, 1))
                
    tensorboard_writer.add_image(tag, img_tensor, global_step=step)
@@ -268,10 +279,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)

    if shared.opts.training_enable_tensorboard:
        os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
        tensorboard_writer = SummaryWriter(
                log_dir=os.path.join(log_directory, "tensorboard"),
                flush_secs=shared.opts.training_tensorboard_flush_every)
        tensorboard_writer = tensorboard_setup(log_directory)

    pbar = tqdm.tqdm(enumerate(ds), total=steps-initial_step)
    for i, entries in pbar:
@@ -308,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
            embedding_yet_to_be_embedded = True

        if shared.opts.training_enable_tensorboard:
            tensorboard_add_scaler(tensorboard_writer, "Loss/train", losses.mean(), embedding.step)
            tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", losses.mean(), epoch_step)
            tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", scheduler.learn_rate, embedding.step)
            tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", scheduler.learn_rate, epoch_step)
            tensorboard_add(tensorboard_writer, loss=losses.mean(), global_step=embedding.step, 
                step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)

        write_loss(log_directory, "textual_inversion_loss.csv", embedding.step, len(ds), {
            "loss": f"{losses.mean():.7f}",
@@ -377,7 +383,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
                embedding_yet_to_be_embedded = False

            image.save(last_saved_image)
            tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

            if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", 
                    image, embedding.step)

            last_saved_image += f", prompt: {preview_text}"