Commit bd68e35d authored by flamelaw's avatar flamelaw
Browse files

Gradient accumulation, autocast fix, new latent sampling method, etc

parent 47a44c7e
Loading
Loading
Loading
Loading
+146 −123
Original line number Diff line number Diff line
@@ -367,13 +367,13 @@ def report_statistics(loss_info:dict):



def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")

    path = shared.hypernetworks.get(hypernetwork_name, None)
    shared.loaded_hypernetwork = Hypernetwork()
@@ -403,29 +403,25 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    hypernetwork = shared.loaded_hypernetwork
    checkpoint = sd_models.select_checkpoint()

    ititial_step = hypernetwork.step or 0
    if ititial_step >= steps:
    initial_step = hypernetwork.step or 0
    if initial_step >= steps:
        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)
    
    size = len(ds.indexes)
    loss_dict = defaultdict(lambda : deque(maxlen = 1024))
    losses = torch.zeros((size,))
    previous_mean_losses = [0]
    previous_mean_loss = 0
    print("Mean loss of {} elements".format(size))
    
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True
@@ -446,62 +442,81 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
            print("Cannot resume from saved optimizer!")
            print(e)

    scaler = torch.cuda.amp.GradScaler()
    
    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0 #internal
    # size = len(ds.indexes)
    # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
    # losses = torch.zeros((size,))
    # previous_mean_losses = [0]
    # previous_mean_loss = 0
    # print("Mean loss of {} elements".format(size))

    steps_without_grad = 0

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"

    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
    for i, entries in pbar:
        hypernetwork.step = i + ititial_step
        if len(loss_dict) > 0:
            previous_mean_losses = [i[-1] for i in loss_dict.values()]
            previous_mean_loss = mean(previous_mean_losses)
            
    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        for i in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, hypernetwork.step)
                if scheduler.finished:
                    break

                if shared.state.interrupted:
                    break

                with torch.autocast("cuda"):
            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
            loss = shared.sd_model(x, c)[0]
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                    loss = shared.sd_model(x, c)[0] / gradient_step
                    del x
                    del c

            losses[hypernetwork.step % losses.shape[0]] = loss.item()
            for entry in entries:
                loss_dict[entry.filename].append(loss.item())
                
            optimizer.zero_grad()
            weights[0].grad = None
            loss.backward()

            if weights[0].grad is None:
                steps_without_grad += 1
            else:
                steps_without_grad = 0
            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'

            optimizer.step()
                    _loss_step += loss.item()
                scaler.scale(loss).backward()
                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
                # scaler.unscale_(optimizer)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = hypernetwork.step + 1
                
        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): 
            raise RuntimeError("Loss diverged.")
        
        if len(previous_mean_losses) > 1:
            std = stdev(previous_mean_losses)
        else:
            std = 0
        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
        pbar.set_description(dataset_loss_info)
                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@@ -512,8 +527,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{previous_mean_loss:.7f}",
                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

@@ -521,7 +536,6 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

            optimizer.zero_grad()
                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

@@ -541,8 +555,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                        p.width = preview_width
                        p.height = preview_height
                    else:
                p.prompt = entries[0].cond_text
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

@@ -562,15 +578,19 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

                shared.state.textinfo = f"""
<p>
Loss: {previous_mean_loss:.7f}<br/>
Loss: {loss_step:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        
    report_statistics(loss_dict)
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        #report_statistics(loss_dict)

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
    hypernetwork.optimizer_name = optimizer_name
@@ -579,6 +599,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
    del optimizer
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)

    return hypernetwork, filename

def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+7 −2
Original line number Diff line number Diff line
@@ -8,7 +8,7 @@ from torch import einsum
from torch.nn.functional import silu

import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.shared import opts, device, cmd_opts
from modules.sd_hijack_optimizations import invokeAI_mps_available

@@ -59,6 +59,10 @@ def undo_optimizations():
def get_target_prompt_token_count(token_count):
    return math.ceil(max(token_count, 1) / 75) * 75

def fix_checkpoint():
    ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward

class StableDiffusionModelHijack:
    fixes = None
@@ -78,6 +82,7 @@ class StableDiffusionModelHijack:
        self.clip = m.cond_stage_model

        apply_optimizations()
        fix_checkpoint()

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
+10 −0
Original line number Diff line number Diff line
from torch.utils.checkpoint import checkpoint

def BasicTransformerBlock_forward(self, x, context=None):
    return checkpoint(self._forward, x, context)

def AttentionBlock_forward(self, x):
    return checkpoint(self._forward, x)

def ResBlock_forward(self, x, emb):
    return checkpoint(self._forward, x, emb)
 No newline at end of file
+1 −2
Original line number Diff line number Diff line
@@ -322,8 +322,7 @@ options_templates.update(options_section(('system', "System"), {

options_templates.update(options_section(('training', "Training"), {
    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
    "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
    "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
    "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
    "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
    "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+86 −48
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ import numpy as np
import PIL
import torch
from PIL import Image
from torch.utils.data import Dataset
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms

import random
@@ -11,25 +11,28 @@ import tqdm
from modules import devices, shared
import re

from ldm.modules.distributions.distributions import DiagonalGaussianDistribution

re_numbers_at_start = re.compile(r"^[-\d]+\s*")


class DatasetEntry:
    def __init__(self, filename=None, latent=None, filename_text=None):
    def __init__(self, filename=None, filename_text=None, latent_dist=None, latent_sample=None, cond=None, cond_text=None, pixel_values=None):
        self.filename = filename
        self.latent = latent
        self.filename_text = filename_text
        self.cond = None
        self.cond_text = None
        self.latent_dist = latent_dist
        self.latent_sample = latent_sample
        self.cond = cond
        self.cond_text = cond_text
        self.pixel_values = pixel_values


class PersonalizedBase(Dataset):
    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, cond_model=None, device=None, template_file=None, include_cond=False, batch_size=1, gradient_step=1, shuffle_tags=False, tag_drop_out=0, latent_sampling_method='once'):        
        re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None
        
        self.placeholder_token = placeholder_token

        self.batch_size = batch_size
        self.width = width
        self.height = height
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
@@ -45,11 +48,16 @@ class PersonalizedBase(Dataset):
        assert os.path.isdir(data_root), "Dataset directory doesn't exist"
        assert os.listdir(data_root), "Dataset directory is empty"

        cond_model = shared.sd_model.cond_stage_model

        self.image_paths = [os.path.join(data_root, file_path) for file_path in os.listdir(data_root)]

        
        self.shuffle_tags = shuffle_tags
        self.tag_drop_out = tag_drop_out

        print("Preparing dataset...")
        for path in tqdm.tqdm(self.image_paths):
            if shared.state.interrupted:
                raise Exception("inturrupted")
            try:
                image = Image.open(path).convert('RGB').resize((self.width, self.height), PIL.Image.BICUBIC)
            except Exception:
@@ -71,37 +79,58 @@ class PersonalizedBase(Dataset):
            npimage = np.array(image).astype(np.uint8)
            npimage = (npimage / 127.5 - 1.0).astype(np.float32)

            torchdata = torch.from_numpy(npimage).to(device=device, dtype=torch.float32)
            torchdata = torch.moveaxis(torchdata, 2, 0)

            init_latent = model.get_first_stage_encoding(model.encode_first_stage(torchdata.unsqueeze(dim=0))).squeeze()
            init_latent = init_latent.to(devices.cpu)

            entry = DatasetEntry(filename=path, filename_text=filename_text, latent=init_latent)

            if include_cond:
            torchdata = torch.from_numpy(npimage).permute(2, 0, 1).to(device=device, dtype=torch.float32)
            latent_sample = None

            with torch.autocast("cuda"):
                latent_dist = model.encode_first_stage(torchdata.unsqueeze(dim=0))

            if latent_sampling_method == "once" or (latent_sampling_method == "deterministic" and not isinstance(latent_dist, DiagonalGaussianDistribution)):
                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
                latent_sampling_method = "once"
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
            elif latent_sampling_method == "deterministic":
                # Works only for DiagonalGaussianDistribution
                latent_dist.std = 0
                latent_sample = model.get_first_stage_encoding(latent_dist).squeeze().to(devices.cpu)
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_sample=latent_sample)
            elif latent_sampling_method == "random":
                entry = DatasetEntry(filename=path, filename_text=filename_text, latent_dist=latent_dist)

            if not (self.tag_drop_out != 0 or self.shuffle_tags):
                entry.cond_text = self.create_text(filename_text)

            if include_cond and not (self.tag_drop_out != 0 or self.shuffle_tags):
                with torch.autocast("cuda"):
                    entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
            # elif not include_cond:
            #     _, _, _, _, hijack_fixes, token_count = cond_model.process_text([entry.cond_text])
            #     max_n = token_count // 75
            #     index_list = [ [] for _ in range(max_n + 1) ]
            #     for n, (z, _) in hijack_fixes[0]:
            #         index_list[n].append(z)
            #     with torch.autocast("cuda"):
            #         entry.cond = cond_model([entry.cond_text]).to(devices.cpu).squeeze(0)
            #     entry.emb_index = index_list

            self.dataset.append(entry)
            del torchdata
            del latent_dist
            del latent_sample

        assert len(self.dataset) > 0, "No images have been found in the dataset."
        self.length = len(self.dataset) * repeats // batch_size

        self.dataset_length = len(self.dataset)
        self.indexes = None
        self.shuffle()

    def shuffle(self):
        self.indexes = np.random.permutation(self.dataset_length)
        self.length = len(self.dataset)
        assert self.length > 0, "No images have been found in the dataset."
        self.batch_size = min(batch_size, self.length)
        self.gradient_step = min(gradient_step, self.length // self.batch_size)
        self.latent_sampling_method = latent_sampling_method

    def create_text(self, filename_text):
        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        tags = filename_text.split(',')
        if shared.opts.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
        if shared.opts.shuffle_tags:
        if self.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > self.tag_drop_out]
        if self.shuffle_tags:
            random.shuffle(tags)
        text = text.replace("[filewords]", ','.join(tags))
        return text
@@ -110,19 +139,28 @@ class PersonalizedBase(Dataset):
        return self.length

    def __getitem__(self, i):
        res = []
        entry = self.dataset[i]
        if self.tag_drop_out != 0 or self.shuffle_tags:
            entry.cond_text = self.create_text(entry.filename_text)
        if self.latent_sampling_method == "random":
            entry.latent_sample = shared.sd_model.get_first_stage_encoding(entry.latent_dist)
        return entry

        for j in range(self.batch_size):
            position = i * self.batch_size + j
            if position % len(self.indexes) == 0:
                self.shuffle()
class PersonalizedDataLoader(DataLoader):
    def __init__(self, *args, **kwargs):
        super(PersonalizedDataLoader, self).__init__(shuffle=True, drop_last=True, *args, **kwargs)
        self.collate_fn = collate_wrapper
        
            index = self.indexes[position % len(self.indexes)]
            entry = self.dataset[index]

            if entry.cond is None:
                entry.cond_text = self.create_text(entry.filename_text)
class BatchLoader:
    def __init__(self, data):
        self.cond_text = [entry.cond_text for entry in data]
        self.cond = [entry.cond for entry in data]
        self.latent_sample = torch.stack([entry.latent_sample for entry in data]).squeeze(1)

            res.append(entry)
    def pin_memory(self):
        self.latent_sample = self.latent_sample.pin_memory()
        return self

        return res
def collate_wrapper(batch):
    return BatchLoader(batch)
 No newline at end of file
Loading