Commit cdfcbd99 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

Remove fallback for Protocol import and remove Protocol import and remove...

Remove fallback for Protocol import and remove Protocol import and remove instances of Protocol in code
add some whitespace between functions to be in line with other code in the repo
parent 89c36630
Loading
Loading
Loading
Loading
+11 −8
Original line number Diff line number Diff line
@@ -15,14 +15,9 @@ import torch
from torch import Tensor
from torch.utils.checkpoint import checkpoint
import math

try:
    from typing import Protocol
except:
    from typing_extensions import Protocol
    
from typing import Optional, NamedTuple, List


def narrow_trunc(
    input: Tensor,
    dim: int,
@@ -31,12 +26,14 @@ def narrow_trunc(
) -> Tensor:
    return torch.narrow(input, dim, start, length if input.shape[dim] >= start + length else input.shape[dim] - start)


class AttnChunk(NamedTuple):
    exp_values: Tensor
    exp_weights_sum: Tensor
    max_score: Tensor

class SummarizeChunk(Protocol):

class SummarizeChunk:
    @staticmethod
    def __call__(
        query: Tensor,
@@ -44,7 +41,8 @@ class SummarizeChunk(Protocol):
        value: Tensor,
    ) -> AttnChunk: ...

class ComputeQueryChunkAttn(Protocol):

class ComputeQueryChunkAttn:
    @staticmethod
    def __call__(
        query: Tensor,
@@ -52,6 +50,7 @@ class ComputeQueryChunkAttn(Protocol):
        value: Tensor,
    ) -> Tensor: ...


def _summarize_chunk(
    query: Tensor,
    key: Tensor,
@@ -72,6 +71,7 @@ def _summarize_chunk(
    max_score = max_score.squeeze(-1)
    return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)


def _query_chunk_attention(
    query: Tensor,
    key: Tensor,
@@ -112,6 +112,7 @@ def _query_chunk_attention(
    all_weights = torch.unsqueeze(chunk_weights, -1).sum(dim=0)
    return all_values / all_weights


# TODO: refactor CrossAttention#get_attention_scores to share code with this
def _get_attention_scores_no_kv_chunking(
    query: Tensor,
@@ -131,10 +132,12 @@ def _get_attention_scores_no_kv_chunking(
    hidden_states_slice = torch.bmm(attn_probs, value)
    return hidden_states_slice


class ScannedChunk(NamedTuple):
    chunk_idx: int
    attn_chunk: AttnChunk


def efficient_dot_product_attention(
    query: Tensor,
    key: Tensor,