Commit bb832d77 authored by Muhammad Rizqi Nur's avatar Muhammad Rizqi Nur
Browse files

Simplify grad clip

parent 3277f90e
Loading
Loading
Loading
Loading
+7 −9
Original line number Diff line number Diff line
@@ -385,10 +385,10 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log

    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)
    
    clip_grad_mode_value = clip_grad_mode == "value"
    clip_grad_mode_norm = clip_grad_mode == "norm"
    clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
    if clip_grad_enabled:
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
    
    # dataset loading may take a while, so input validations and early returns should be done before this
@@ -433,7 +433,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
        if shared.state.interrupted:
            break

        if clip_grad_enabled:
        if clip_grad:
            clip_grad_sched.step(hypernetwork.step)

        with torch.autocast("cuda"):
@@ -458,10 +458,8 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
                steps_without_grad = 0
            assert steps_without_grad < 10, 'no gradient found for the trained weight after backward() for 10 steps in a row; this is a bug; training cannot continue'

            if clip_grad_mode_value:
                torch.nn.utils.clip_grad_value_(weights, clip_value=clip_grad_sched.learn_rate)
            elif clip_grad_mode_norm:
                torch.nn.utils.clip_grad_norm_(weights, max_norm=clip_grad_sched.learn_rate)
            if clip_grad:
                clip_grad(weights, clip_grad_sched.learn_rate)

            optimizer.step()

+7 −9
Original line number Diff line number Diff line
@@ -269,10 +269,10 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc

    scheduler = LearnRateScheduler(learn_rate, steps, ititial_step)

    clip_grad_mode_value = clip_grad_mode == "value"
    clip_grad_mode_norm = clip_grad_mode == "norm"
    clip_grad_enabled = clip_grad_mode_value or clip_grad_mode_norm
    if clip_grad_enabled:
    clip_grad = torch.nn.utils.clip_grad_value_ if clip_grad_mode == "value" else \
        torch.nn.utils.clip_grad_norm_ if clip_grad_mode == "norm" else \
        None
    if clip_grad:
        clip_grad_sched = LearnRateScheduler(clip_grad_value, steps, ititial_step, verbose=False)
    # dataset loading may take a while, so input validations and early returns should be done before this
    shared.state.textinfo = f"Preparing dataset from {html.escape(data_root)}..."
@@ -302,7 +302,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
        if shared.state.interrupted:
            break

        if clip_grad_enabled:
        if clip_grad:
            clip_grad_sched.step(embedding.step)

        with torch.autocast("cuda"):
@@ -316,10 +316,8 @@ def train_embedding(embedding_name, learn_rate, batch_size, data_root, log_direc
            optimizer.zero_grad()
            loss.backward()

            if clip_grad_mode_value:
                torch.nn.utils.clip_grad_value_(embedding.vec, clip_value=clip_grad_sched.learn_rate)
            elif clip_grad_mode_norm:
                torch.nn.utils.clip_grad_norm_(embedding.vec, max_norm=clip_grad_sched.learn_rate)
            if clip_grad:
                clip_grad(embedding.vec, clip_grad_sched.learn_rate)

            optimizer.step()