Commit d52a80f7 authored by Shondoit's avatar Shondoit
Browse files

Allow creation of zero vectors for TI

parent 0b8911d8
Loading
Loading
Loading
Loading
+6 −3
Original line number Diff line number Diff line
@@ -248,9 +248,12 @@ def create_embedding(name, num_vectors_per_token, overwrite_old, init_text='*'):
    with devices.autocast():
        cond_model([""])  # will send cond model to GPU if lowvram/medvram is active

    embedded = cond_model.encode_embedding_init_text(init_text, num_vectors_per_token)
    #cond_model expects at least some text, so we provide '*' as backup.
    embedded = cond_model.encode_embedding_init_text(init_text or '*', num_vectors_per_token)
    vec = torch.zeros((num_vectors_per_token, embedded.shape[1]), device=devices.device)

    #Only copy if we provided an init_text, otherwise keep vectors as zeros
    if init_text:
        for i in range(num_vectors_per_token):
            vec[i] = embedded[i * int(embedded.shape[0]) // num_vectors_per_token]