Commit b48b7999 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

Merge remote-tracking branch 'flamelaw/master'

parents b0063827 755df94b
Loading
Loading
Loading
Loading
+166 −128
Original line number Diff line number Diff line
@@ -38,7 +38,7 @@ class HypernetworkModule(torch.nn.Module):
    activation_dict.update({cls_name.lower(): cls_obj for cls_name, cls_obj in inspect.getmembers(torch.nn.modules.activation) if inspect.isclass(cls_obj) and cls_obj.__module__ == 'torch.nn.modules.activation'})

    def __init__(self, dim, state_dict=None, layer_structure=None, activation_func=None, weight_init='Normal',
                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=True):
                 add_layer_norm=False, use_dropout=False, activate_output=False, last_layer_dropout=False):
        super().__init__()

        assert layer_structure is not None, "layer_structure must not be None"
@@ -154,16 +154,28 @@ class Hypernetwork:
                HypernetworkModule(size, None, self.layer_structure, self.activation_func, self.weight_init,
                                   self.add_layer_norm, self.use_dropout, self.activate_output, last_layer_dropout=self.last_layer_dropout),
            )
        self.eval_mode()

    def weights(self):
        res = []
        for k, layers in self.layers.items():
            for layer in layers:
                res += layer.parameters()
        return res

    def train_mode(self):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.train()
                res += layer.trainables()
                for param in layer.parameters():
                    param.requires_grad = True

        return res
    def eval_mode(self):
        for k, layers in self.layers.items():
            for layer in layers:
                layer.eval()
                for param in layer.parameters():
                    param.requires_grad = False

    def save(self, filename):
        state_dict = {}
@@ -367,13 +379,13 @@ def report_statistics(loss_info:dict):



def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

    save_hypernetwork_every = save_hypernetwork_every or 0
    create_image_every = create_image_every or 0
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")
    textual_inversion.validate_train_inputs(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_hypernetwork_every, create_image_every, log_directory, name="hypernetwork")

    path = shared.hypernetworks.get(hypernetwork_name, None)
    shared.loaded_hypernetwork = Hypernetwork()
@@ -403,32 +415,30 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    hypernetwork = shared.loaded_hypernetwork
    checkpoint = sd_models.select_checkpoint()

    ititial_step = hypernetwork.step or 0
    if ititial_step >= steps:
    initial_step = hypernetwork.step or 0
    if initial_step >= steps:
        shared.state.textinfo = f"Model has already been trained beyond specified max steps"
        return hypernetwork, filename

    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)

    pin_memory = shared.opts.pin_memory

    ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, cond_model=shared.sd_model.cond_stage_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size, gradient_step=gradient_step, shuffle_tags=shuffle_tags, tag_drop_out=tag_drop_out, latent_sampling_method=latent_sampling_method)
    
    latent_sampling_method = ds.latent_sampling_method

    dl = modules.textual_inversion.dataset.PersonalizedDataLoader(ds, latent_sampling_method=latent_sampling_method, batch_size=ds.batch_size, pin_memory=pin_memory)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
        shared.sd_model.first_stage_model.to(devices.cpu)
    
    size = len(ds.indexes)
    loss_dict = defaultdict(lambda : deque(maxlen = 1024))
    losses = torch.zeros((size,))
    previous_mean_losses = [0]
    previous_mean_loss = 0
    print("Mean loss of {} elements".format(size))
    
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True
    hypernetwork.train_mode()

    # Here we use optimizer from saved HN, or we can specify as UI option.
    if hypernetwork.optimizer_name in optimizer_dict:
@@ -446,62 +456,81 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
            print("Cannot resume from saved optimizer!")
            print(e)

    scaler = torch.cuda.amp.GradScaler()
    
    batch_size = ds.batch_size
    gradient_step = ds.gradient_step
    # n steps = batch_size * gradient_step * n image processed
    steps_per_epoch = len(ds) // batch_size // gradient_step
    max_steps_per_epoch = len(ds) // batch_size - (len(ds) // batch_size) % gradient_step
    loss_step = 0
    _loss_step = 0 #internal
    # size = len(ds.indexes)
    # loss_dict = defaultdict(lambda : deque(maxlen = 1024))
    # losses = torch.zeros((size,))
    # previous_mean_losses = [0]
    # previous_mean_loss = 0
    # print("Mean loss of {} elements".format(size))

    steps_without_grad = 0

    last_saved_file = "<none>"
    last_saved_image = "<none>"
    forced_filename = "<none>"

    pbar = tqdm.tqdm(enumerate(ds), total=steps - ititial_step)
    for i, entries in pbar:
        hypernetwork.step = i + ititial_step
        if len(loss_dict) > 0:
            previous_mean_losses = [i[-1] for i in loss_dict.values()]
            previous_mean_loss = mean(previous_mean_losses)
            
    pbar = tqdm.tqdm(total=steps - initial_step)
    try:
        for i in range((steps-initial_step) * gradient_step):
            if scheduler.finished:
                break
            if shared.state.interrupted:
                break
            for j, batch in enumerate(dl):
                # works as a drop_last=True for gradient accumulation
                if j == max_steps_per_epoch:
                    break
                scheduler.apply(optimizer, hypernetwork.step)
                if scheduler.finished:
                    break

                if shared.state.interrupted:
                    break

                with torch.autocast("cuda"):
            c = stack_conds([entry.cond for entry in entries]).to(devices.device)
            # c = torch.vstack([entry.cond for entry in entries]).to(devices.device)
            x = torch.stack([entry.latent for entry in entries]).to(devices.device)
            loss = shared.sd_model(x, c)[0]
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
                    loss = shared.sd_model(x, c)[0] / gradient_step
                    del x
                    del c

            losses[hypernetwork.step % losses.shape[0]] = loss.item()
            for entry in entries:
                loss_dict[entry.filename].append(loss.item())
                
            optimizer.zero_grad()
            weights[0].grad = None
            loss.backward()

            if weights[0].grad is None:
                steps_without_grad += 1
            else:
                steps_without_grad = 0
            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'

            optimizer.step()
                    _loss_step += loss.item()
                scaler.scale(loss).backward()
                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
                # scaler.unscale_(optimizer)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
                pbar.update()
                optimizer.zero_grad(set_to_none=True)
                loss_step = _loss_step
                _loss_step = 0

                steps_done = hypernetwork.step + 1
                
        if torch.isnan(losses[hypernetwork.step % losses.shape[0]]): 
            raise RuntimeError("Loss diverged.")
        
        if len(previous_mean_losses) > 1:
            std = stdev(previous_mean_losses)
        else:
            std = 0
        dataset_loss_info = f"dataset loss:{mean(previous_mean_losses):.3f}" + u"\u00B1" + f"({std / (len(previous_mean_losses) ** 0.5):.3f})"
        pbar.set_description(dataset_loss_info)
                epoch_num = hypernetwork.step // steps_per_epoch
                epoch_step = hypernetwork.step % steps_per_epoch

                pbar.set_description(f"[Epoch {epoch_num}: {epoch_step+1}/{steps_per_epoch}]loss: {loss_step:.7f}")
                if hypernetwork_dir is not None and steps_done % save_hypernetwork_every == 0:
                    # Before saving, change name to match current checkpoint.
                    hypernetwork_name_every = f'{hypernetwork_name}-{steps_done}'
@@ -512,16 +541,15 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, last_saved_file)
                    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.

        textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, len(ds), {
            "loss": f"{previous_mean_loss:.7f}",
                textual_inversion.write_loss(log_directory, "hypernetwork_loss.csv", hypernetwork.step, steps_per_epoch, {
                    "loss": f"{loss_step:.7f}",
                    "learn_rate": scheduler.learn_rate
                })

                if images_dir is not None and steps_done % create_image_every == 0:
                    forced_filename = f'{hypernetwork_name}-{steps_done}'
                    last_saved_image = os.path.join(images_dir, forced_filename)

            optimizer.zero_grad()
                    hypernetwork.eval_mode()
                    shared.sd_model.cond_stage_model.to(devices.device)
                    shared.sd_model.first_stage_model.to(devices.device)

@@ -541,8 +569,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                        p.width = preview_width
                        p.height = preview_height
                    else:
                p.prompt = entries[0].cond_text
                        p.prompt = batch.cond_text[0]
                        p.steps = 20
                        p.width = training_width
                        p.height = training_height

                    preview_text = p.prompt

@@ -552,7 +582,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                    if unload:
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                        shared.sd_model.first_stage_model.to(devices.cpu)

                    hypernetwork.train_mode()
                    if image is not None:
                        shared.state.current_image = image
                        last_saved_image, last_text_info = images.save_image(image, images_dir, "", p.seed, p.prompt, shared.opts.samples_format, processed.infotexts[0], p=p, forced_filename=forced_filename, save_to_dirs=False)
@@ -562,15 +592,20 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

                shared.state.textinfo = f"""
<p>
Loss: {previous_mean_loss:.7f}<br/>
Step: {hypernetwork.step}<br/>
Last prompt: {html.escape(entries[0].cond_text)}<br/>
Loss: {loss_step:.7f}<br/>
Step: {steps_done}<br/>
Last prompt: {html.escape(batch.cond_text[0])}<br/>
Last saved hypernetwork: {html.escape(last_saved_file)}<br/>
Last saved image: {html.escape(last_saved_image)}<br/>
</p>
"""
        
    report_statistics(loss_dict)
    except Exception:
        print(traceback.format_exc(), file=sys.stderr)
    finally:
        pbar.leave = False
        pbar.close()
        hypernetwork.eval_mode()
        #report_statistics(loss_dict)

    filename = os.path.join(shared.cmd_opts.hypernetwork_dir, f'{hypernetwork_name}.pt')
    hypernetwork.optimizer_name = optimizer_name
@@ -579,6 +614,9 @@ Last saved image: {html.escape(last_saved_image)}<br/>
    save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename)
    del optimizer
    hypernetwork.optimizer_state_dict = None  # dereference it after saving, to save memory.
    shared.sd_model.cond_stage_model.to(devices.device)
    shared.sd_model.first_stage_model.to(devices.device)

    return hypernetwork, filename

def save_hypernetwork(hypernetwork, checkpoint, hypernetwork_name, filename):
+7 −2
Original line number Diff line number Diff line
@@ -8,9 +8,9 @@ from torch import einsum
from torch.nn.functional import silu

import modules.textual_inversion.textual_inversion
from modules import prompt_parser, devices, sd_hijack_optimizations, shared
from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules.shared import opts, device, cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip

from modules.sd_hijack_optimizations import invokeAI_mps_available
@@ -66,6 +66,10 @@ def undo_optimizations():
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


def fix_checkpoint():
    ldm.modules.attention.BasicTransformerBlock.forward = sd_hijack_checkpoint.BasicTransformerBlock_forward
    ldm.modules.diffusionmodules.openaimodel.ResBlock.forward = sd_hijack_checkpoint.ResBlock_forward
    ldm.modules.diffusionmodules.openaimodel.AttentionBlock.forward = sd_hijack_checkpoint.AttentionBlock_forward

class StableDiffusionModelHijack:
    fixes = None
@@ -88,6 +92,7 @@ class StableDiffusionModelHijack:
        self.clip = m.cond_stage_model

        apply_optimizations()
        fix_checkpoint()

        def flatten(el):
            flattened = [flatten(children) for children in el.children()]
+10 −0
Original line number Diff line number Diff line
from torch.utils.checkpoint import checkpoint

def BasicTransformerBlock_forward(self, x, context=None):
    return checkpoint(self._forward, x, context)

def AttentionBlock_forward(self, x):
    return checkpoint(self._forward, x)

def ResBlock_forward(self, x, emb):
    return checkpoint(self._forward, x, emb)
 No newline at end of file
+1 −2
Original line number Diff line number Diff line
@@ -345,8 +345,7 @@ options_templates.update(options_section(('system', "System"), {

options_templates.update(options_section(('training', "Training"), {
    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
    "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
    "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
    "pin_memory": OptionInfo(False, "Turn on pin_memory for DataLoader. Makes training slightly faster but can increase memory usage."),
    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
    "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
    "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+92 −48

File changed.

Preview size limit exceeded, changes collapsed.

Loading