Commit f7c787eb authored by AUTOMATIC

make it possible to use hypernetworks without opt split attention

parent 97bc0b95
modules/hypernetwork.py +34 −8
@@ -4,7 +4,12 @@ import sys
 import traceback
 
 import torch
-from modules import devices
+
+from ldm.util import default
+from modules import devices, shared
+import torch
+from torch import einsum
+from einops import rearrange, repeat
 
 
 class HypernetworkModule(torch.nn.Module):
@@ -48,15 +53,36 @@ def load_hypernetworks(path):
 
     return res
 
-def apply(self, x, context=None, mask=None, original=None):
+
+def attention_CrossAttention_forward(self, x, context=None, mask=None):
     h = self.heads
 
     q = self.to_q(x)
     context = default(context, x)
 
-    if CrossAttention.hypernetwork is not None and context.shape[2] in CrossAttention.hypernetwork:
-        if context.shape[1] == 77 and CrossAttention.noise_cond:
-            context = context + (torch.randn_like(context) * 0.1)
-        h_k, h_v = CrossAttention.hypernetwork[context.shape[2]]
-        k = self.to_k(h_k(context))
-        v = self.to_v(h_v(context))
+    hypernetwork = shared.selected_hypernetwork()
+    hypernetwork_layers = (hypernetwork.layers if hypernetwork is not None else {}).get(context.shape[2], None)
+
+    if hypernetwork_layers is not None:
+        k = self.to_k(hypernetwork_layers[0](context))
+        v = self.to_v(hypernetwork_layers[1](context))
+    else:
+        k = self.to_k(context)
+        v = self.to_v(context)
+
+    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+
+    sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+
+    if mask is not None:
+        mask = rearrange(mask, 'b ... -> b (...)')
+        max_neg_value = -torch.finfo(sim.dtype).max
+        mask = repeat(mask, 'b j -> (b h) () j', h=h)
+        sim.masked_fill_(~mask, max_neg_value)
+
+    # attention, what we cannot get enough of
+    attn = sim.softmax(dim=-1)
+
+    out = einsum('b i j, b j d -> b i d', attn, v)
+    out = rearrange(out, '(b h) n d -> b n (h d)', h=h)
+    return self.to_out(out)
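
Note: the new forward relies on shared.selected_hypernetwork() returning the currently selected Hypernetwork (or None), with its .layers dict keyed by context width. That helper is not part of this commit; a minimal sketch of what it presumably looks like on the shared.py side — the names hypernetworks and opts.sd_hypernetwork here are assumptions, not shown in the diff:

    # Hypothetical sketch, not verbatim webui code.
    hypernetworks = {}  # name -> Hypernetwork, filled by load_hypernetworks()

    class _Opts:
        sd_hypernetwork = "None"  # stand-in for the UI setting

    opts = _Opts()

    def selected_hypernetwork():
        # Return the chosen Hypernetwork, or None so the forward falls
        # back to plain to_k/to_v on the unmodified context.
        return hypernetworks.get(opts.sd_hypernetwork, None)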
modules/sd_hijack.py +4 −2
@@ -8,7 +8,7 @@ from torch import einsum
 from torch.nn.functional import silu
 
 import modules.textual_inversion.textual_inversion
-from modules import prompt_parser, devices, sd_hijack_optimizations, shared
+from modules import prompt_parser, devices, sd_hijack_optimizations, shared, hypernetwork
 from modules.shared import opts, device, cmd_opts
 
 import ldm.modules.attention
@@ -20,6 +20,8 @@ diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward
 
 
 def apply_optimizations():
+    undo_optimizations()
+
     ldm.modules.diffusionmodules.model.nonlinearity = silu
 
     if cmd_opts.opt_split_attention_v1:
@@ -30,7 +32,7 @@ def apply_optimizations():
 
 
 def undo_optimizations():
-    ldm.modules.attention.CrossAttention.forward = attention_CrossAttention_forward
+    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
     ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
     ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
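
Taken together, the sd_hijack.py changes mean undo_optimizations() now installs the hypernetwork-aware forward as the baseline, and apply_optimizations() resets to that baseline before layering an optimized forward on top — so hypernetworks keep working when no split-attention flag is set. A toy reproduction of that monkey-patching pattern, with all names invented for illustration (not the webui's actual API):

    # Toy sketch of the patch-first-then-optimize ordering.
    class CrossAttention:
        def forward(self, x):
            return x  # stock behavior

    def hypernetwork_forward(self, x):
        return x + 1  # stands in for the hypernetwork-aware forward

    def optimized_forward(self, x):
        return x + 2  # stands in for a split-attention variant

    def undo_optimizations():
        CrossAttention.forward = hypernetwork_forward  # new baseline

    def apply_optimizations(opt_split_attention=False):
        undo_optimizations()  # always reset to the baseline first
        if opt_split_attention:
            CrossAttention.forward = optimized_forward  # opt-in override

    apply_optimizations(opt_split_attention=False)
    assert CrossAttention().forward(0) == 1  # hypernetwork path, no opt needed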