Commit 62e3d71a authored by AUTOMATIC's avatar AUTOMATIC
Browse files

Rework the code to avoid the walrus operator, because Colab's Python 3.7 does not support it

parent b8f2dfed
Loading
Loading
Loading
Loading
+5 −2
Original line number Diff line number Diff line
@@ -429,13 +429,16 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, data_root, log
    weights = hypernetwork.weights()
    for weight in weights:
        weight.requires_grad = True

    # Here we use optimizer from saved HN, or we can specify as UI option.
    if (optimizer_name := hypernetwork.optimizer_name) in optimizer_dict:
    if hypernetwork.optimizer_name in optimizer_dict:
        optimizer = optimizer_dict[hypernetwork.optimizer_name](params=weights, lr=scheduler.learn_rate)
        optimizer_name = hypernetwork.optimizer_name
    else:
        print(f"Optimizer type {optimizer_name} is not defined!")
        print(f"Optimizer type {hypernetwork.optimizer_name} is not defined!")
        optimizer = torch.optim.AdamW(params=weights, lr=scheduler.learn_rate)
        optimizer_name = 'AdamW'

    if hypernetwork.optimizer_state_dict:  # This line must be changed if Optimizer type can be different from saved optimizer.
        try:
            optimizer.load_state_dict(hypernetwork.optimizer_state_dict)