Commit 467cae16 authored by TinkTheBoush's avatar TinkTheBoush
Browse files

append_tag_shuffle

parent c28de154
Loading
Loading
Loading
Loading
+2 −2
Original line number Diff line number Diff line
@@ -331,7 +331,7 @@ def report_statistics(loss_info:dict):



def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

@@ -376,7 +376,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=hypernetwork_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, include_cond=True, batch_size=batch_size)

    if unload:
        shared.sd_model.cond_stage_model.to(devices.cpu)
+8 −2
Original line number Diff line number Diff line
@@ -24,7 +24,7 @@ class DatasetEntry:


class PersonalizedBase(Dataset):
    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", model=None, device=None, template_file=None, include_cond=False, batch_size=1):
    def __init__(self, data_root, width, height, repeats, flip_p=0.5, placeholder_token="*", shuffle_tags=True, model=None, device=None, template_file=None, include_cond=False, batch_size=1):
        re_word = re.compile(shared.opts.dataset_filename_word_regex) if len(shared.opts.dataset_filename_word_regex) > 0 else None

        self.placeholder_token = placeholder_token
@@ -33,6 +33,7 @@ class PersonalizedBase(Dataset):
        self.width = width
        self.height = height
        self.flip = transforms.RandomHorizontalFlip(p=flip_p)
        self.shuffle_tags = shuffle_tags

        self.dataset = []

@@ -98,6 +99,11 @@ class PersonalizedBase(Dataset):
    def create_text(self, filename_text):
        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        if self.tag_shuffle:
            tags = filename_text.split(',')
            random.shuffle(tags)
            text = text.replace("[filewords]", ','.join(tags))
        else:
            text = text.replace("[filewords]", filename_text)
        return text

+2 −2
Original line number Diff line number Diff line
@@ -224,7 +224,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, data_root, templat
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"

def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_directory, training_width, training_height, steps, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, shuffle_tags, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(embedding_name, learn_rate, batch_size, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
@@ -271,7 +271,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    with torch.autocast("cuda"):
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)
        ds = modules.textual_inversion.dataset.PersonalizedBase(data_root=data_root, width=training_width, height=training_height, repeats=shared.opts.training_image_repeats_per_epoch, placeholder_token=embedding_name, shuffle_tags=shuffle_tags, model=shared.sd_model, device=devices.device, template_file=template_file, batch_size=batch_size)

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
+3 −0
Original line number Diff line number Diff line
@@ -1267,6 +1267,7 @@ def create_ui(wrap_gradio_gpu_call):
                    save_embedding_every = gr.Number(label='Save a copy of embedding to log directory every N steps, 0 to disable', value=500, precision=0)
                    save_image_with_stored_embedding = gr.Checkbox(label='Save images with embedding in PNG chunks', value=True)
                    preview_from_txt2img = gr.Checkbox(label='Read parameters (prompt, etc...) from txt2img tab when making previews', value=False)
                    shuffle_tags = gr.Checkbox(label='Shuffleing tags by "," when create texts', value=True)

                    with gr.Row():
                        interrupt_training = gr.Button(value="Interrupt")
@@ -1361,6 +1362,7 @@ def create_ui(wrap_gradio_gpu_call):
                template_file,
                save_image_with_stored_embedding,
                preview_from_txt2img,
                shuffle_tags,
                *txt2img_preview_params,
            ],
            outputs=[
@@ -1385,6 +1387,7 @@ def create_ui(wrap_gradio_gpu_call):
                save_embedding_every,
                template_file,
                preview_from_txt2img,
                shuffle_tags,
                *txt2img_preview_params,
            ],
            outputs=[