Unverified Commit f0dfed2a authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #5796 from brkirch/invoke-fix

Improve InvokeAI cross attention reliability and speed when using MPS for large images
parents 7115ab5d 35b1775b
Loading
Loading
Loading
Loading
+6 −4
Original line number Diff line number Diff line
@@ -127,7 +127,7 @@ def check_for_psutil():

invokeAI_mps_available = check_for_psutil()

# -- Taken from https://github.com/invoke-ai/InvokeAI --
# -- Taken from https://github.com/invoke-ai/InvokeAI and modified --
if invokeAI_mps_available:
    import psutil
    mem_total_gb = psutil.virtual_memory().total // (1 << 30)
@@ -152,14 +152,16 @@ def einsum_op_slice_1(q, k, v, slice_size):
    return r

def einsum_op_mps_v1(q, k, v):
    if q.shape[1] <= 4096: # (512x512) max q.shape[1]: 4096
    if q.shape[0] * q.shape[1] <= 2**16: # (512x512) max q.shape[1]: 4096
        return einsum_op_compvis(q, k, v)
    else:
        slice_size = math.floor(2**30 / (q.shape[0] * q.shape[1]))
        if slice_size % 4096 == 0:
            slice_size -= 1
        return einsum_op_slice_1(q, k, v, slice_size)

def einsum_op_mps_v2(q, k, v):
    if mem_total_gb > 8 and q.shape[1] <= 4096:
    if mem_total_gb > 8 and q.shape[0] * q.shape[1] <= 2**16:
        return einsum_op_compvis(q, k, v)
    else:
        return einsum_op_slice_0(q, k, v, 1)
@@ -188,7 +190,7 @@ def einsum_op(q, k, v):
        return einsum_op_cuda(q, k, v)

    if q.device.type == 'mps':
        if mem_total_gb >= 32:
        if mem_total_gb >= 32 and q.shape[0] % 32 != 0 and q.shape[0] * q.shape[1] < 2**18:
            return einsum_op_mps_v1(q, k, v)
        return einsum_op_mps_v2(q, k, v)