Commit 755df94b authored by flamelaw
Browse files

set TI AdamW default weight decay to 0

parent 1bd57cc9
Loading
Loading
Loading
Loading
+1 −1
Original line number | Diff line number | Diff line
@@ -283,7 +283,7 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
        shared.sd_model.first_stage_model.to(devices.cpu)

    embedding.vec.requires_grad = True
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate)
    optimizer = torch.optim.AdamW([embedding.vec], lr=scheduler.learn_rate, weight_decay=0.0)
    scaler = torch.cuda.amp.GradScaler()

    batch_size = ds.batch_size