Unverified Commit 1849f6eb authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #3264 from Melanpan/tensorboard

Add support for Tensorboard (training)
parents 544e7a23 9cd77167
Loading
Loading
Loading
Loading
+14 −1
Original line number Diff line number Diff line
@@ -24,7 +24,6 @@ from statistics import stdev, mean

optimizer_dict = {optim_name : cls_obj for optim_name, cls_obj in inspect.getmembers(torch.optim, inspect.isclass) if optim_name != "Optimizer"}


class HypernetworkModule(torch.nn.Module):
    multiplier = 1.0
    activation_dict = {
@@ -498,6 +497,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = textual_inversion.tensorboard_setup(log_directory)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."

@@ -632,6 +634,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.



                if shared.opts.training_enable_tensorboard:
                    epoch_num = hypernetwork.step // len(ds)
                    epoch_step = hypernetwork.step - (epoch_num * len(ds)) + 1

                    textual_inversion.tensorboard_add(tensorboard_writer, loss=mean_loss, global_step=hypernetwork.step, step=epoch_step, learn_rate=scheduler.learn_rate, epoch_num=epoch_num)

                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
@@ -674,6 +684,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                    processed = processing.process_images(p)
                    image = processed.images[0] if len(processed.images) > 0 else None
                    
                    if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                        textual_inversion.tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, hypernetwork.step)

                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)
+3 −0
Original line number Diff line number Diff line
@@ -373,6 +373,9 @@ options_templates.update(options_section(('training', "Training"), {
    "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
    "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
    "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
    "training_enable_tensorboard": OptionInfo(False, "Enable tensorboard logging."),
    "training_tensorboard_save_images": OptionInfo(False, "Save generated images within tensorboard."),
    "training_tensorboard_flush_every": OptionInfo(120, "How often, in seconds, to flush the pending tensorboard events and summaries to disk."),
}))

options_templates.update(options_section(('sd', "Stable Diffusion"), {
+31 −0
Original line number Diff line number Diff line
@@ -12,6 +12,7 @@ import csv
import safetensors.torch

from PIL import Image, PngImagePlugin
from torch.utils.tensorboard import SummaryWriter

from modules import shared, devices, sd_hijack, processing, sd_models, images, sd_samplers
import modules.textual_inversion.dataset
@@ -294,6 +295,30 @@ def write_loss(log_directory, filename, step, epoch_len, values):
            **values,
        })

def tensorboard_setup(log_directory):
    os.makedirs(os.path.join(log_directory, "tensorboard"), exist_ok=True)
    return SummaryWriter(
            log_dir=os.path.join(log_directory, "tensorboard"),
            flush_secs=shared.opts.training_tensorboard_flush_every)

def tensorboard_add(tensorboard_writer, loss, global_step, step, learn_rate, epoch_num):
    tensorboard_add_scaler(tensorboard_writer, "Loss/train", loss, global_step)
    tensorboard_add_scaler(tensorboard_writer, f"Loss/train/epoch-{epoch_num}", loss, step)
    tensorboard_add_scaler(tensorboard_writer, "Learn rate/train", learn_rate, global_step)
    tensorboard_add_scaler(tensorboard_writer, f"Learn rate/train/epoch-{epoch_num}", learn_rate, step)

def tensorboard_add_scaler(tensorboard_writer, tag, value, step):
    tensorboard_writer.add_scalar(tag=tag, 
        scalar_value=value, global_step=step)

def tensorboard_add_image(tensorboard_writer, tag, pil_image, step):
    # Convert a pil image to a torch tensor
    img_tensor = torch.as_tensor(np.array(pil_image, copy=True))
    img_tensor = img_tensor.view(pil_image.size[1], pil_image.size[0], 
        len(pil_image.getbands()))
    img_tensor = img_tensor.permute((2, 0, 1))
                
    tensorboard_writer.add_image(tag, img_tensor, global_step=step)

def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, data_root, template_file, template_filename, steps, save_model_every, create_image_every, log_directory, name="embedding"):
    assert model_name, f"{name} not selected"
@@ -373,6 +398,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed
    
    if shared.opts.training_enable_tensorboard:
        tensorboard_writer = tensorboard_setup(log_directory)

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method, varsize=varsize)
@@ -535,6 +563,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
                        last_saved_image += f", prompt: {preview_text}"

                        if shared.opts.training_enable_tensorboard and shared.opts.training_tensorboard_save_images:
                            tensorboard_add_image(tensorboard_writer, f"Validation at epoch {epoch_num}", image, embedding.step)

                    if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:

                        last_saved_image_chunks = os.path.join(images_embeds_dir, f'{embedding_name}-{steps_done}.png')