Commit 184e6701 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

fix the merge

parent 8839b372
Loading
Loading
Loading
Loading
+5 −9
Original line number Diff line number Diff line
@@ -251,6 +251,7 @@ def validate_train_inputs(model_name, learn_rate, batch_size, gradient_step, dat
    if save_model_every or create_image_every:
        assert log_directory, "Log directory is empty"


def create_dummy_mask(x, width=None, height=None):
    if shared.sd_model.model.conditioning_key in {'hybrid', 'concat'}:

@@ -380,17 +381,12 @@ def train_embedding(embedding_name, learn_rate, batch_size, gradient_step, data_
                    break

                with devices.autocast():
                    # c = stack_conds(batch.cond).to(devices.device)
                    # mask = torch.tensor(batch.emb_index).to(devices.device, non_blocking=pin_memory)
                    # print(mask)
                    # c[:, 1:1+embedding.vec.shape[0]] = embedding.vec.to(devices.device, non_blocking=pin_memory)
                    
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)

                    if img_c is None:
                        img_c = create_dummy_mask(c, training_width, training_height)

                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    c = shared.sd_model.cond_stage_model(batch.cond_text)
                    cond = {"c_concat": [img_c], "c_crossattn": [c]}
                    loss = shared.sd_model(x, cond)[0] / gradient_step
                    del x