Commit 006756f9 authored by Fampai's avatar Fampai
Browse files

Added TI training optimizations

option to use xattention optimizations when training
option to unload vae when training
parent 700162a6
Loading
Loading
Loading
Loading
+2 −1
Original line number Diff line number Diff line
@@ -256,11 +256,12 @@ options_templates.update(options_section(('system', "System"), {
}))

options_templates.update(options_section(('training', "Training"), {
    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
    "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
    "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
    "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
    "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
    "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
}))

options_templates.update(options_section(('sd', "Stable Diffusion"), {
+9 −0
Original line number Diff line number Diff line
@@ -214,6 +214,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
    filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')

    log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
    unload = shared.opts.unload_models_when_training

    if save_embedding_every > 0:
        embedding_dir = os.path.join(log_directory, "embeddings")
@@ -238,6 +239,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
    if unload:
        shared.sd_model.first_stage_model.to(devices.cpu)

    hijack = sd_hijack.model_hijack

@@ -303,6 +306,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
        if images_dir is not None and steps_done % create_image_every == 0:
            forced_filename = f'{embedding_name}-{steps_done}'
            last_saved_image = os.path.join(images_dir, forced_filename)

            shared.sd_model.first_stage_model.to(devices.device)

            p = processing.StableDiffusionProcessingTxt2Img(
                sd_model=shared.sd_model,
                do_not_save_grid=True,
@@ -330,6 +336,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
            processed = processing.process_images(p)
            image = processed.images[0]

            if unload:
                shared.sd_model.first_stage_model.to(devices.cpu)

            shared.state.current_image = image

            if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
+5 −2
Original line number Diff line number Diff line
@@ -25,7 +25,9 @@ def train_embedding(*args):

    assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'

    apply_optimizations = shared.opts.training_xattention_optimizations
    try:
        if not apply_optimizations:
            sd_hijack.undo_optimizations()

        embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
    except Exception:
        raise
    finally:
        if not apply_optimizations:
            sd_hijack.apply_optimizations()