Unverified Commit 9092e1ca authored by AUTOMATIC1111, committed by GitHub

Merge pull request #3842 from R-N/gradient-clipping

Gradient clipping in train tab
parents b7deea47 eeb1de43
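For reference, the two PyTorch utilities that the new clip_grad_mode option selects between behave differently: "value" mode clamps each gradient element independently into [-value, value], while "norm" mode rescales the whole gradient so its total norm stays under the threshold. A minimal standalone sketch with toy tensors (not webui code):

```python
import torch

# Toy parameter with an artificially large gradient (values are made up).
w = torch.nn.Parameter(torch.ones(3))
w.grad = torch.tensor([10.0, -20.0, 5.0])

# "value" mode: clamp each gradient element into [-0.1, 0.1].
torch.nn.utils.clip_grad_value_([w], clip_value=0.1)
print(w.grad)  # tensor([ 0.1000, -0.1000,  0.1000])

# "norm" mode: rescale the gradients so their total L2 norm is at most 0.1.
w.grad = torch.tensor([10.0, -20.0, 5.0])
torch.nn.utils.clip_grad_norm_([w], max_norm=0.1)
print(w.grad.norm())  # ~0.1
```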
modules/hypernetworks/hypernetwork.py  +14 −9
@@ -402,10 +402,8 @@ def create_hypernetwork(name, enable_sizes, overwrite_old, layer_structure=None,

    shared.reload_hypernetworks()

    return fn


def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_hypernetwork_every, template_file, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    # images allows training previews to have infotext. Importing it at the top causes a circular import problem.
    from modules import images

@@ -449,6 +447,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,

    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)
    
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)

    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."

@@ -525,6 +527,9 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(hypernetwork.step)
                
                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
@@ -539,14 +544,14 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,

                    _loss_step += loss.item()
                scaler.scale(loss).backward()
                
                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.7f}")
                # scaler.unscale_(optimizer)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")
                # torch.nn.utils.clip_grad_norm_(weights, max_norm=1.0)
                # print(f"grad:{weights[0].grad.detach().cpu().abs().mean().item():.15f}")

                if clip_grad:
                    clip_grad(weights, clip_grad_sched.learn_rate)
                
                scaler.step(optimizer)
                scaler.update()
                hypernetwork.step += 1
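For context on where clipping sits relative to the GradScaler calls in the loop above, PyTorch's documented AMP recipe unscales gradients before applying a clip; the sketch below follows that generic recipe (including the scaler.unscale_ call that appears commented out in the hunk above) with a stand-in model, and is not the webui training loop itself:

```python
import torch

device = "cuda" if torch.cuda.is_available() else "cpu"
model = torch.nn.Linear(4, 1).to(device)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.cuda.amp.GradScaler(enabled=(device == "cuda"))

x, y = torch.randn(8, 4, device=device), torch.randn(8, 1, device=device)
for _ in range(3):
    optimizer.zero_grad(set_to_none=True)
    with torch.autocast(device_type=device, enabled=(device == "cuda")):
        loss = torch.nn.functional.mse_loss(model(x), y)
    scaler.scale(loss).backward()
    # Unscale first so the clip threshold applies to the true gradient values.
    scaler.unscale_(optimizer)
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=0.1)
    scaler.step(optimizer)
    scaler.update()
```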
modules/textual_inversion/learn_schedule.py  +8 −3
@@ -58,14 +58,19 @@ class LearnRateScheduler:

        self.finished = False

    def apply(self, optimizer, step_number):
    def step(self, step_number):
        if step_number < self.end_step:
            return
            return False

        try:
            (self.learn_rate, self.end_step) = next(self.schedules)
        except Exception:
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        if not self.step(step_number):
            return

        if self.verbose:
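The LearnRateScheduler change above splits schedule advancement (step) from optimizer mutation (apply), so the same class can also drive the clipping threshold, which has no optimizer to update. A rough standalone sketch of that split, with made-up schedule values:

```python
class MiniScheduler:
    """Toy version of the step()/apply() split; schedule entries are made up."""
    def __init__(self, schedule):
        self.schedules = iter(schedule)  # [(value, end_step), ...]
        self.learn_rate, self.end_step = next(self.schedules)
        self.finished = False

    def step(self, step_number):
        # Advance to the next schedule entry; report whether anything changed.
        if step_number < self.end_step:
            return False
        try:
            self.learn_rate, self.end_step = next(self.schedules)
        except StopIteration:
            self.finished = True
            return False
        return True

    def apply(self, optimizer, step_number):
        # Only touch the optimizer when the schedule actually advanced.
        if not self.step(step_number):
            return
        for group in optimizer.param_groups:
            group["lr"] = self.learn_rate

# A clip-value scheduler can call step() alone and read .learn_rate as the threshold.
sched = MiniScheduler([(0.1, 100), (0.05, 1000)])
sched.step(100)
print(sched.learn_rate)  # 0.05
```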
modules/textual_inversion/textual_inversion.py  +13 −2
@@ -251,8 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"


def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_root, log_directory, training_width, training_height, steps, clip_grad_mode, clip_grad_value, shuffle_tags, tag_drop_out, latent_sampling_method, create_image_every, save_embedding_every, template_file, save_image_with_stored_embedding, preview_from_txt2img, preview_prompt, preview_negative_prompt, preview_steps, preview_sampler_index, preview_cfg_scale, preview_seed, preview_width, preview_height):
    save_embedding_every = save_embedding_every or 0
    create_image_every = create_image_every or 0
    validate_train_inputs(embedding_name, learn_rate, batch_size, gradient_step, data_root, template_file, steps, save_embedding_every, create_image_every, log_directory, name="embedding")
@@ -295,6 +294,11 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
        return embedding, filename
    scheduler = LearnRateScheduler(learn_rate, steps, initial_step)

    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, initial_step, verbose=False)
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
    old_parallel_processing_allowed = shared.parallel_processing_allowed
@@ -361,6 +365,9 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                if shared.state.interrupted:
                    break

                if clip_grad:
                    clip_grad_sched.step(embedding.step)
            
                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)
@@ -382,6 +389,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                # go back until we reach gradient accumulation steps
                if (j + 1) % gradient_step != 0:
                    continue
                
                if clip_grad:
                    clip_grad(embedding.vec, clip_grad_sched.learn_rate)

                scaler.step(optimizer)
                scaler.update()
                embedding.step += 1
modules/ui.py  +8 −0
@@ -1291,6 +1291,10 @@ def create_ui():
                        embedding_learn_rate = gr.Textbox(label='Embedding Learning rate', placeholder="Embedding Learning rate", value="0.005", elem_id="train_embedding_learn_rate")
                        hypernetwork_learn_rate = gr.Textbox(label='Hypernetwork Learning rate', placeholder="Hypernetwork Learning rate", value="0.00001", elem_id="train_hypernetwork_learn_rate")
                    
                    with gr.Row():
                        clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping", choices=["disabled", "value", "norm"])
                        clip_grad_value = gr.Textbox(placeholder="Gradient clip value", value="0.1", show_label=False)

                    batch_size = gr.Number(label='Batch size', value=1, precision=0, elem_id="train_batch_size")
                    gradient_step = gr.Number(label='Gradient accumulation steps', value=1, precision=0, elem_id="train_gradient_step")
                    dataset_directory = gr.Textbox(label='Dataset directory', placeholder="Path to directory with input images", elem_id="train_dataset_directory")
@@ -1402,6 +1406,8 @@ def create_ui():
                training_width,
                training_height,
                steps,
                clip_grad_mode,
                clip_grad_value,
                shuffle_tags,
                tag_drop_out,
                latent_sampling_method,
@@ -1431,6 +1437,8 @@ def create_ui():
                training_width,
                training_height,
                steps,
                clip_grad_mode,
                clip_grad_value,
                shuffle_tags,
                tag_drop_out,
                latent_sampling_method,
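To illustrate how the two new UI components end up as extra positional arguments of the train functions, here is a cut-down Gradio sketch with a stubbed-out handler; the names and layout are simplified and this is not the actual create_ui code:

```python
import gradio as gr

def train_stub(steps, clip_grad_mode, clip_grad_value):
    # Stand-in for train_embedding / train_hypernetwork.
    return f"mode={clip_grad_mode}, value={clip_grad_value}, steps={steps}"

with gr.Blocks() as demo:
    steps = gr.Number(label="Max steps", value=100, precision=0)
    with gr.Row():
        clip_grad_mode = gr.Dropdown(value="disabled", label="Gradient Clipping",
                                     choices=["disabled", "value", "norm"])
        clip_grad_value = gr.Textbox(placeholder="Gradient clip value",
                                     value="0.1", show_label=False)
    out = gr.Textbox(label="Result")
    btn = gr.Button("Train")
    # The new dropdown/textbox are simply appended to the handler's inputs list.
    btn.click(train_stub, inputs=[steps, clip_grad_mode, clip_grad_value], outputs=out)

if __name__ == "__main__":
    demo.launch()
```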