Commit 59146621 authored by AUTOMATIC

better support for xformers flash attention on older versions of torch

parent 3fa48207
modules/errors.py +12 −0
@@ -24,6 +24,18 @@ See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable
         """)
 
 
+already_displayed = {}
+
+
+def display_once(e: Exception, task):
+    if task in already_displayed:
+        return
+
+    display(e, task)
+
+    already_displayed[task] = 1
+
+
 def run(code, task):
     try:
         code()
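The new display_once helper keys on the task string, so a failure that repeats on every call (such as the flash-attention probe added below) is reported only once instead of flooding the console. A minimal sketch of the intended call pattern, with might_fail as a hypothetical operation and assuming the webui's modules.errors package:

    from modules import errors  # webui's errors module, extended by this commit

    def might_fail():  # hypothetical operation that can raise on every invocation
        raise RuntimeError("flash attention unavailable")

    for _ in range(100):
        try:
            might_fail()
        except Exception as e:
            # Displayed once for the "enabling flash attention" task key,
            # then silently deduplicated on the remaining iterations.
            errors.display_once(e, "enabling flash attention")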
modules/sd_hijack_optimizations.py +18 −24
@@ -9,7 +9,7 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared
+from modules import shared, errors
 from modules.hypernetworks import hypernetwork
 
 from .sub_quadratic_attention import efficient_dot_product_attention
@@ -279,6 +279,21 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     )
 
 
+def get_xformers_flash_attention_op(q, k, v):
+    if not shared.cmd_opts.xformers_flash_attention:
+        return None
+
+    try:
+        flash_attention_op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
+        fw, bw = flash_attention_op
+        if fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
+            return flash_attention_op
+    except Exception as e:
+        errors.display_once(e, "enabling flash attention")
+
+    return None
+
+
 def xformers_attention_forward(self, x, context=None, mask=None):
     h = self.heads
     q_in = self.to_q(x)
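The try/except around the probe is what buys compatibility with older torch and xformers builds: on those versions, MemoryEfficientAttentionFlashAttentionOp, xformers.ops.fmha.Inputs, or the supports check may be missing or may raise rather than returning False, and any such failure now degrades to op=None, letting xformers auto-select a backend. A minimal sketch of calling the helper, with hypothetical tensor shapes in xformers' (batch, tokens, heads, head_dim) layout and assuming a CUDA build:

    import torch
    import xformers.ops

    # Hypothetical fp16 attention inputs: (batch, tokens, heads, head_dim).
    q = k = v = torch.randn(2, 4096, 8, 40, device="cuda", dtype=torch.float16)

    # Flash-attention op pair if the probe succeeds, else None (auto-select).
    op = get_xformers_flash_attention_op(q, k, v)
    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)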
@@ -291,18 +306,7 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
-    if shared.cmd_opts.xformers_flash_attention:
-        op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
-        fw, bw = op
-        if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v, attn_bias=None)):
-            # print('xformers_attention_forward', q.shape, k.shape, v.shape)
-            # Flash Attention is not availabe for the input arguments.
-            # Fallback to default xFormers' backend.
-            op = None
-    else:
-        op = None
-
-    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=op)
+    out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
 
     out = rearrange(out, 'b n h d -> b n (h d)', h=h)
     return self.to_out(out)
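For reference, the rearrange calls that bracket the attention call split the heads out of the channel dimension before attention and merge them back afterwards. A self-contained illustration with hypothetical shapes:

    import torch
    from einops import rearrange

    b, n, heads, d = 2, 77, 8, 40
    x = torch.randn(b, n, heads * d)                     # (batch, tokens, heads*dim)
    q = rearrange(x, 'b n (h d) -> b n h d', h=heads)    # split out a per-head axis
    # ... memory_efficient_attention consumes (b, n, h, d) tensors here ...
    out = rearrange(q, 'b n h d -> b n (h d)', h=heads)  # merge heads back
    assert out.shape == x.shape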
@@ -377,17 +381,7 @@ def xformers_attnblock_forward(self, x):
         q = q.contiguous()
         k = k.contiguous()
         v = v.contiguous()
-        if shared.cmd_opts.xformers_flash_attention:
-            op = xformers.ops.MemoryEfficientAttentionFlashAttentionOp
-            fw, bw = op
-            if not fw.supports(xformers.ops.fmha.Inputs(query=q, key=k, value=v)):
-                # print('xformers_attnblock_forward', q.shape, k.shape, v.shape)
-                # Flash Attention is not availabe for the input arguments.
-                # Fallback to default xFormers' backend.
-                op = None
-        else:
-            op = None
-        out = xformers.ops.memory_efficient_attention(q, k, v, op=op)
+        out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
         out = rearrange(out, 'b (h w) c -> b c h w', h=h)
         out = self.proj_out(out)
         return x + out
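xformers_attnblock_forward applies the same helper to self-attention over feature maps, where the spatial grid is flattened to a token sequence before attention and restored afterwards. A self-contained illustration with hypothetical shapes:

    import torch
    from einops import rearrange

    b, c, h, w = 1, 512, 64, 64
    x = torch.randn(b, c, h, w)                           # (batch, channels, height, width)
    tokens = rearrange(x, 'b c h w -> b (h w) c')         # one token per spatial position
    # ... memory_efficient_attention runs over the (h w) token axis here ...
    out = rearrange(tokens, 'b (h w) c -> b c h w', h=h)  # restore the spatial grid
    assert out.shape == x.shape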