Commit 006756f9 authored by Fampai

Added TI training optimizations

option to use cross-attention (xattention) optimizations when training
option to unload the VAE when training
parent 700162a6
modules/shared.py +2 −1
@@ -256,11 +256,12 @@ options_templates.update(options_section(('system', "System"), {
 }))
 
 options_templates.update(options_section(('training', "Training"), {
-    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training hypernetwork. Saves VRAM."),
+    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
     "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
     "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
     "training_image_repeats_per_epoch": OptionInfo(1, "Number of repeats for a single input image per epoch; used only for displaying epoch number", gr.Number, {"precision": 0}),
     "training_write_csv_every": OptionInfo(500, "Save an csv containing the loss to log directory every N steps, 0 to disable"),
+    "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
 }))
 
 options_templates.update(options_section(('sd', "Stable Diffusion"), {
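
Note: options declared through OptionInfo in this table surface later as attributes on shared.opts (the next file reads shared.opts.unload_models_when_training). A minimal, self-contained sketch of that registry pattern, illustrative only and not the webui implementation:

    # Sketch of an OptionInfo-style settings registry (names and shape are
    # assumptions for illustration; the real webui class has more machinery).
    class OptionInfo:
        def __init__(self, default, label, component=None, component_args=None):
            self.default = default
            self.label = label
            self.component = component
            self.component_args = component_args

    class Options:
        def __init__(self, templates):
            # every registered key starts out at its default value
            self.data = {key: info.default for key, info in templates.items()}

        def __getattr__(self, name):
            if name in self.data:
                return self.data[name]
            raise AttributeError(name)

    templates = {
        "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
        "training_xattention_optimizations": OptionInfo(False, "Use cross attention optimizations while training"),
    }
    opts = Options(templates)
    print(opts.training_xattention_optimizations)  # False until the user enables it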
modules/textual_inversion/textual_inversion.py +9 −0
@@ -214,6 +214,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     filename = os.path.join(shared.cmd_opts.embeddings_dir, f'{embedding_name}.pt')
 
     log_directory = os.path.join(log_directory, datetime.datetime.now().strftime("%Y-%m-%d"), embedding_name)
+    unload = shared.opts.unload_models_when_training
 
     if save_embedding_every > 0:
         embedding_dir = os.path.join(log_directory, "embeddings")
@@ -238,6 +239,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
     shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
     with torch.autocast("cuda"):
         ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
+    if unload:
+        shared.sd_model.first_stage_model.to(devices.cpu)
 
     hijack = sd_hijack.model_hijack
 
@@ -303,6 +306,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
         if images_dir is not None and steps_done % create_image_every == 0:
             forced_filename = f'{embedding_name}-{steps_done}'
             last_saved_image = os.path.join(images_dir, forced_filename)
+
+            shared.sd_model.first_stage_model.to(devices.device)
+
             p = processing.StableDiffusionProcessingTxt2Img(
                 sd_model=shared.sd_model,
                 do_not_save_grid=True,
@@ -330,6 +336,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
             processed = processing.process_images(p)
             image = processed.images[0]
 
+            if unload:
+                shared.sd_model.first_stage_model.to(devices.cpu)
+
             shared.state.current_image = image
 
             if save_image_with_stored_embedding and os.path.exists(last_saved_file) and embedding_yet_to_be_embedded:
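
Note: these four hunks implement a simple offload pattern. When the unload option is on, the VAE (first_stage_model) is moved to the CPU right after the dataset is prepared (the dataset already holds encoded latents at that point, so training steps do not need the VAE), pulled back onto the training device only to render preview images, and then returned to the CPU. A standalone sketch of the same pattern with a stand-in module; the module, device, and step counts are placeholders, not the webui code:

    import torch

    def train_with_offloaded_vae(vae, unload, device, steps=4, preview_every=2):
        # Keep the VAE in RAM while training; the training step itself never
        # touches it, so its VRAM can go to the optimizer instead.
        if unload:
            vae.to("cpu")
        for step in range(1, steps + 1):
            pass  # training step runs on `device`; the VAE is idle here
            if step % preview_every == 0:
                vae.to(device)  # the VAE must be on-device to decode a preview
                # ... decode and show a preview image here ...
                if unload:
                    vae.to("cpu")  # hand the VRAM back to the training loop

    vae = torch.nn.Conv2d(4, 3, 1)  # stand-in for sd_model.first_stage_model
    train_with_offloaded_vae(vae, unload=True, device="cpu")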
modules/textual_inversion/ui.py +5 −2
@@ -25,7 +25,9 @@ def train_embedding(*args):
 
     assert not shared.cmd_opts.lowvram, 'Training models with lowvram not possible'
 
+    apply_optimizations = shared.opts.training_xattention_optimizations
     try:
-        sd_hijack.undo_optimizations()
+        if not apply_optimizations:
+            sd_hijack.undo_optimizations()
 
         embedding, filename = modules.textual_inversion.textual_inversion.train_embedding(*args)
@@ -38,5 +40,6 @@ Embedding saved to {html.escape(filename)}
     except Exception:
         raise
     finally:
+        if not apply_optimizations:
             sd_hijack.apply_optimizations()
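
Note: the wrapper's try/finally guarantees the inference-time cross-attention hijack is restored even if training raises; the new flag simply skips both the undo and the re-apply when the user opts into training with the optimizations active. A minimal sketch of that guard, with stub bodies, where only the control flow mirrors the diff:

    def undo_optimizations():
        print("training with plain cross attention")

    def apply_optimizations():
        print("inference optimizations restored")

    def train_embedding(apply_opts_while_training: bool):
        try:
            if not apply_opts_while_training:
                undo_optimizations()
            raise RuntimeError("simulated training failure")
        except Exception:
            raise
        finally:
            # runs on success or failure, so the UI never stays de-optimized
            if not apply_opts_while_training:
                apply_optimizations()

    try:
        train_embedding(apply_opts_while_training=False)
    except RuntimeError:
        pass  # optimizations were still re-applied by the finally block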