Unverified Commit c9e5b921 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #10266 from nero-dv/dev

Update sub_quadratic_attention.py
parents 8aa87c56 c8732dfa
Loading
Loading
Loading
Loading
+15 −6
Original line number Diff line number Diff line
@@ -202,13 +202,22 @@ def efficient_dot_product_attention(
            value=value,
        )
    
    # TODO: maybe we should use torch.empty_like(query) to allocate storage in-advance,
    # and pass slices to be mutated, instead of torch.cat()ing the returned slices
    res = torch.cat([
        compute_query_chunk_attn(
    # slices of res tensor are mutable, modifications made
    # to the slices will affect the original tensor.
    # if output of compute_query_chunk_attn function has same number of
    # dimensions as input query tensor, we initialize tensor like this:
    num_query_chunks = int(np.ceil(q_tokens / query_chunk_size))
    query_shape = get_query_chunk(0).shape
    res_shape = (query_shape[0], query_shape[1] * num_query_chunks, *query_shape[2:])
    res_dtype = get_query_chunk(0).dtype
    res = torch.zeros(res_shape, dtype=res_dtype)

    for i in range(num_query_chunks):
        attn_scores = compute_query_chunk_attn(
            query=get_query_chunk(i * query_chunk_size),
            key=key,
            value=value,
        ) for i in range(math.ceil(q_tokens / query_chunk_size))
    ], dim=1)
        )
        res[:, i * query_chunk_size:(i + 1) * query_chunk_size, :] = attn_scores

    return res