Commit bc509367 authored by Shondoit's avatar Shondoit
Browse files

Call weighted_forward during training

parent 21642000
Loading
Loading
Loading
Loading
+2 −1
Original line number | Diff line number | Diff line
@@ -640,13 +640,14 @@ def train_hypernetwork(id_task, hypernetwork_name, learn_rate, batch_size, gradi
                
                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                   w = batch.weight.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
                        c = shared.sd_model.cond_stage_model(batch.cond_text).to(devices.device, non_blocking=pin_memory)
                        shared.sd_model.cond_stage_model.to(devices.cpu)
                    else:
                        c = stack_conds(batch.cond).to(devices.device, non_blocking=pin_memory)
-                   loss = shared.sd_model(x, c)[0] / gradient_step
+                   loss = shared.sd_model.weighted_forward(x, c, w)[0] / gradient_step
                    del x
                    del c

+2 −1
Original line number | Diff line number | Diff line
@@ -480,6 +480,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
            
                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
+                   w = batch.weight.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)

                    if is_training_inpainting_model:
@@ -490,7 +491,7 @@ def train_embedding(id_task, embedding_name, learn_rate, batch_size, gradient_st
                    else:
                        cond = c

-                   loss = shared.sd_model(x, cond)[0] / gradient_step
+                   loss = shared.sd_model.weighted_forward(x, cond, w)[0] / gradient_step
                    del x

                    _loss_step += loss.item()