Commit b1198153 authored by brkirch's avatar brkirch
Browse files

Use narrow instead of dynamic_slice

parent 3bfe2bb5
Loading
Loading
Loading
Loading
+19 −15
Original line number Diff line number Diff line
@@ -5,6 +5,7 @@
# credit:
#   Amin Rezaei (original author)
#   Alex Birch (optimized algorithm for 3D tensors, at the expense of removing bias, masking and callbacks)
#   brkirch (modified to use torch.narrow instead of dynamic_slice implementation)
# implementation of:
#   "Self-attention Does Not Need O(n^2) Memory":
#   https://arxiv.org/abs/2112.05682v2
@@ -16,13 +17,13 @@ from torch.utils.checkpoint import checkpoint
import math
from typing import Optional, NamedTuple, Protocol, List

def dynamic_slice(
    x: Tensor,
    starts: List[int],
    sizes: List[int],
def narrow_trunc(
    input: Tensor,
    dim: int,
    start: int,
    length: int
) -> Tensor:
    slicing = [slice(start, start + size) for start, size in zip(starts, sizes)]
    return x[slicing]
    return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)

class AttnChunk(NamedTuple):
    exp_values: Tensor
@@ -76,15 +77,17 @@ def _query_chunk_attention(
    _, _, v_channels_per_head = value.shape

    def chunk_scanner(chunk_idx: int) -> AttnChunk:
        key_chunk = dynamic_slice(
        key_chunk = narrow_trunc(
            key,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, k_channels_per_head)
            1,
            chunk_idx,
            kv_chunk_size
        )
        value_chunk = dynamic_slice(
        value_chunk = narrow_trunc(
            value,
            (0, chunk_idx, 0),
            (batch_x_heads, kv_chunk_size, v_channels_per_head)
            1,
            chunk_idx,
            kv_chunk_size
        )
        return summarize_chunk(query, key_chunk, value_chunk)

@@ -161,10 +164,11 @@ def efficient_dot_product_attention(
        kv_chunk_size = max(kv_chunk_size, kv_chunk_size_min)

    def get_query_chunk(chunk_idx: int) -> Tensor:
        return dynamic_slice(
        return narrow_trunc(
            query,
            (0, chunk_idx, 0),
            (batch_x_heads, min(query_chunk_size, q_tokens), q_channels_per_head)
            1,
            chunk_idx,
            min(query_chunk_size, q_tokens)
        )
    
    summarize_chunk: SummarizeChunk = partial(_summarize_chunk, scale=scale)