Commit abfa4ad8 authored by brkirch

Use fixed size for sub-quadratic chunking on MPS

Even though this can make the chunks much smaller, performance isn't significantly impacted. This will usually reduce memory usage, and it should also help with the poor performance that occurs when free memory is low; a worked size check follows below.
parent 3163d126
+5 −1
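To see why a fixed budget still forces chunking at the sizes that matter, here is a back-of-the-envelope check. This is an illustrative sketch, not code from the commit; the function name is hypothetical, but the size formula mirrors the qk_matmul_size_bytes computation in the diff below.

def exceeds_chunk_budget(batch_x_heads, q_tokens, k_tokens,
                         bytes_per_token, threshold_bytes):
    # Bytes needed to materialize the full QK^T attention score matrix
    # in one shot, as computed in sub_quad_attention.
    qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
    return qk_matmul_size_bytes > threshold_bytes

# Example: 16 batch*heads, 9216 tokens per side (a 768x768 image has a
# 96x96 = 9216-position latent), float16 at 2 bytes per element:
# 16 * 2 * 9216 * 9216 bytes is about 2.5 GiB, well over the fixed
# 512 MiB float16 budget, so attention falls back to chunked evaluation.
print(exceeds_chunk_budget(16, 9216, 9216, 2, 268435456 * 2))  # True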
 from __future__ import annotations
 import math
 import psutil
+import platform
 
 import torch
 from torch import einsum
@@ -427,7 +428,10 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
     qk_matmul_size_bytes = batch_x_heads * bytes_per_token * q_tokens * k_tokens
 
     if chunk_threshold is None:
-        chunk_threshold_bytes = int(get_available_vram() * 0.9) if q.device.type == 'mps' else int(get_available_vram() * 0.7)
+        if q.device.type == 'mps':
+            chunk_threshold_bytes = 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)
+        else:
+            chunk_threshold_bytes = int(get_available_vram() * 0.7)
     elif chunk_threshold == 0:
         chunk_threshold_bytes = None
     else:
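For reference, 268435456 is 2**28, so on Apple Silicon the new budget scales with the element size of the attention tensors: 512 MiB for float16 (bytes_per_token == 2) and 1 GiB for float32 (bytes_per_token == 4). Intel Macs, where platform.processor() reports 'i386', are pinned to a factor of 2, i.e. 512 MiB regardless of dtype. A minimal standalone sketch of the same branch (the helper name is illustrative, not from the repository):

import platform

def mps_chunk_threshold_bytes(bytes_per_token):
    # 268435456 == 2**28. Scale the budget by element size on Apple
    # Silicon; pin Intel Macs ('i386') to a factor of 2 (512 MiB).
    return 268435456 * (2 if platform.processor() == 'i386' else bytes_per_token)

# float16 (2 bytes per element): 2**28 * 2 = 536870912 bytes = 512 MiB
# on either processor, since both branches then multiply by 2.
assert mps_chunk_threshold_bytes(2) == 512 * 1024 * 1024

Unlike the previous get_available_vram()-based heuristic, this threshold no longer depends on how much memory happens to be free, which is what makes the resulting chunk sizes fixed and keeps performance predictable when free memory is low.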