Unverified Commit 73776907 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #4117 from TinkTheBoush/master

Adding optional tag shuffling for training
parents 6585cba2 a1e27120
Loading
Loading
Loading
Loading
+2 −0
Original line number Original line Diff line number Diff line
@@ -319,6 +319,8 @@ options_templates.update(options_section(('system', "System"), {


options_templates.update(options_section(('training', "Training"), {
options_templates.update(options_section(('training', "Training"), {
    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
    "unload_models_when_training": OptionInfo(False, "Move VAE and CLIP to RAM when training if possible. Saves VRAM."),
    "shuffle_tags": OptionInfo(False, "Shuffleing tags by ',' when create texts."),
    "tag_drop_out": OptionInfo(0, "Dropout tags when create texts", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.1}),
    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
    "save_optimizer_state": OptionInfo(False, "Saves Optimizer state as separate *.optim file. Training can be resumed with HN itself and matching optim file."),
    "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
    "dataset_filename_word_regex": OptionInfo("", "Filename word regex"),
    "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
    "dataset_filename_join_string": OptionInfo(" ", "Filename join string"),
+6 −1
Original line number Original line Diff line number Diff line
@@ -98,7 +98,12 @@ class PersonalizedBase(Dataset):
    def create_text(self, filename_text):
    def create_text(self, filename_text):
        text = random.choice(self.lines)
        text = random.choice(self.lines)
        text = text.replace("[name]", self.placeholder_token)
        text = text.replace("[name]", self.placeholder_token)
        text = text.replace("[filewords]", filename_text)
        tags = filename_text.split(',')
        if shared.opts.tag_drop_out != 0:
            tags = [t for t in tags if random.random() > shared.opts.tag_drop_out]
        if shared.opts.shuffle_tags:
            random.shuffle(tags)
        text = text.replace("[filewords]", ','.join(tags))
        return text
        return text


    def __len__(self):
    def __len__(self):