Commit 2582a0fd authored by AUTOMATIC

make it possible for scripts to add cross attention optimizations

add UI selection for cross attention optimization
parent 2e006fa5
+7 −7
@@ -53,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
+21 −0
@@ -110,6 +110,7 @@ callback_map = dict(
    callbacks_script_unloaded=[],
    callbacks_before_ui=[],
    callbacks_on_reload=[],
+    callbacks_list_optimizers=[],
)


@@ -258,6 +259,18 @@ def before_ui_callback():
            report_exception(c, 'before_ui')


+def list_optimizers_callback():
+    res = []
+
+    for c in callback_map['callbacks_list_optimizers']:
+        try:
+            c.callback(res)
+        except Exception:
+            report_exception(c, 'list_optimizers')
+
+    return res


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -409,3 +422,11 @@ def on_before_ui(callback):
    """register a function to be called before the UI is created."""

    add_callback(callback_map['callbacks_before_ui'], callback)


+def on_list_optimizers(callback):
+    """register a function to be called when UI is making a list of cross attention optimization options.
+    The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
+    to it."""
+
+    add_callback(callback_map['callbacks_list_optimizers'], callback)
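
Taken together, the new callback and the SdOptimization base class (added to modules/sd_hijack_optimizations.py below) let extension scripts contribute their own cross attention optimizations. A minimal sketch of such a script — ExampleOptimization and its stub forward function are hypothetical, only on_list_optimizers and the SdOptimization interface come from this commit:

# Hypothetical extension script: registers a custom cross attention
# optimization so it appears in the new settings dropdown.
import ldm.modules.attention

from modules import script_callbacks, sd_hijack_optimizations


def example_attention_forward(self, x, context=None, mask=None):
    raise NotImplementedError("an extension's own attention kernel would go here")


class ExampleOptimization(sd_hijack_optimizations.SdOptimization):
    def __init__(self):
        super().__init__("example", label="illustrative extension optimization")

    def is_available(self):
        return True  # e.g. check that an optional dependency imports

    def priority(self):
        return 50  # ranks between sdp (80) and Doggettx (20) for the "Automatic" choice

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = example_attention_forward


script_callbacks.on_list_optimizers(lambda res: res.append(ExampleOptimization()))

The base class's undo() already restores the default CrossAttention and AttnBlock forwards, so an optimization that only swaps those does not need to override it.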
+49 −41
@@ -3,8 +3,9 @@ from torch.nn.functional import silu
from types import MethodType

import modules.textual_inversion.textual_inversion
-from modules import devices, sd_hijack_optimizations, shared
+from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
from modules.hypernetworks import hypernetwork
+from modules.sd_hijack_optimizations import diffusionmodules_model_AttnBlock_forward
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr

@@ -28,57 +29,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None

+optimizers = []
+current_optimizer: sd_hijack_optimizations.SdOptimization = None
+
+
+def list_optimizers():
+    new_optimizers = script_callbacks.list_optimizers_callback()
+
+    new_optimizers = [x for x in new_optimizers if x.is_available()]
+
+    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority(), reverse=True)
+
+    optimizers.clear()
+    optimizers.extend(new_optimizers)


def apply_optimizations():
+    global current_optimizer

    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

-    optimization_method = None
+    if current_optimizer is not None:
+        current_optimizer.undo()
+        current_optimizer = None

-    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp
-
-    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
-        print("Applying xformers cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
-        optimization_method = 'xformers'
-    elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
-        print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
-        optimization_method = 'sdp-no-mem'
-    elif cmd_opts.opt_sdp_attention and can_use_sdp:
-        print("Applying scaled dot product cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
-        optimization_method = 'sdp'
-    elif cmd_opts.opt_sub_quad_attention:
-        print("Applying sub-quadratic cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
-        optimization_method = 'sub-quadratic'
-    elif cmd_opts.opt_split_attention_v1:
-        print("Applying v1 cross attention optimization.")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
-        optimization_method = 'V1'
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
-        print("Applying cross attention optimization (InvokeAI).")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
-        optimization_method = 'InvokeAI'
-    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
-        print("Applying cross attention optimization (Doggettx).")
-        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
-        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
-        optimization_method = 'Doggettx'
-
-    return optimization_method
+    selection = shared.opts.cross_attention_optimization
+    if selection == "Automatic" and len(optimizers) > 0:
+        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
+    else:
+        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)
+
+    if selection == "None":
+        matching_optimizer = None
+    elif matching_optimizer is None:
+        matching_optimizer = optimizers[0]
+
+    if matching_optimizer is not None:
+        print(f"Applying optimization: {matching_optimizer.name}")
+        matching_optimizer.apply()
+        current_optimizer = matching_optimizer
+        return current_optimizer.name
+    else:
+        return ''


def undo_optimizations():
-    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
+    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


@@ -169,7 +169,11 @@ class StableDiffusionModelHijack:
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

-        self.optimization_method = apply_optimizations()
+        try:
+            self.optimization_method = apply_optimizations()
+        except Exception as e:
+            errors.display(e, "applying cross attention optimization")
+            undo_optimizations()

        self.clip = m.cond_stage_model

@@ -223,6 +227,10 @@ class StableDiffusionModelHijack:

        return token_count, self.clip.get_target_prompt_token_count(token_count)

+    def redo_hijack(self, m):
+        self.undo_hijack(m)
+        self.hijack(m)


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
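
With the hijack changes above, the result of the automatic choice can be inspected at runtime. A small sketch, assuming an initialized webui environment:

# Rebuild the optimizer list and print the order "Automatic" considers.
from modules import sd_hijack, shared

sd_hijack.list_optimizers()  # runs the callbacks, filters by is_available(), sorts by priority() descending

for opt in sd_hijack.optimizers:
    preferred = opt.cmd_opt and getattr(shared.cmd_opts, opt.cmd_opt, False)
    print(opt.priority(), opt.title(), "(preferred via command line flag)" if preferred else "")

Under "Automatic", the first optimizer whose cmd_opt flag is set on the command line wins; with no flags set, the highest-priority available optimizer is applied, and selecting "None" disables optimization entirely.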
+132 −3
@@ -9,10 +9,139 @@ from torch import einsum
from ldm.util import default
from einops import rearrange

-from modules import shared, errors, devices
+from modules import shared, errors, devices, sub_quadratic_attention, script_callbacks
from modules.hypernetworks import hypernetwork

-from .sub_quadratic_attention import efficient_dot_product_attention
+import ldm.modules.attention
+import ldm.modules.diffusionmodules.model
+
+diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward


+class SdOptimization:
+    def __init__(self, name, label=None, cmd_opt=None):
+        self.name = name
+        self.label = label
+        self.cmd_opt = cmd_opt
+
+    def title(self):
+        if self.label is None:
+            return self.name
+
+        return f"{self.name} - {self.label}"
+
+    def is_available(self):
+        return True
+
+    def priority(self):
+        return 0
+
+    def apply(self):
+        pass
+
+    def undo(self):
+        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward
+
+
+class SdOptimizationXformers(SdOptimization):
+    def __init__(self):
+        super().__init__("xformers", cmd_opt="xformers")
+
+    def is_available(self):
+        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))
+
+    def priority(self):
+        return 100
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward
+
+
+class SdOptimizationSdpNoMem(SdOptimization):
+    def __init__(self, name="sdp-no-mem", label="scaled dot product without memory efficient attention", cmd_opt="opt_sdp_no_mem_attention"):
+        super().__init__(name, label, cmd_opt)
+
+    def is_available(self):
+        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
+
+    def priority(self):
+        return 90
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward
+
+
+class SdOptimizationSdp(SdOptimizationSdpNoMem):
+    def __init__(self):
+        super().__init__("sdp", "scaled dot product", cmd_opt="opt_sdp_attention")
+
+    def priority(self):
+        return 80
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward
+
+
+class SdOptimizationSubQuad(SdOptimization):
+    def __init__(self):
+        super().__init__("sub-quadratic", cmd_opt="opt_sub_quad_attention")
+
+    def priority(self):
+        return 10
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward
+
+
+class SdOptimizationV1(SdOptimization):
+    def __init__(self):
+        super().__init__("V1", "original v1", cmd_opt="opt_split_attention_v1")
+
+    def priority(self):
+        return 10
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
+
+
+class SdOptimizationInvokeAI(SdOptimization):
+    def __init__(self):
+        super().__init__("InvokeAI", cmd_opt="opt_split_attention_invokeai")
+
+    def priority(self):
+        return 1000 if not torch.cuda.is_available() else 10
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI
+
+
+class SdOptimizationDoggettx(SdOptimization):
+    def __init__(self):
+        super().__init__("Doggettx", cmd_opt="opt_split_attention")
+
+    def priority(self):
+        return 20
+
+    def apply(self):
+        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
+        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward
+
+
+def list_optimizers(res):
+    res.extend([
+        SdOptimizationXformers(),
+        SdOptimizationSdpNoMem(),
+        SdOptimizationSdp(),
+        SdOptimizationSubQuad(),
+        SdOptimizationV1(),
+        SdOptimizationInvokeAI(),
+        SdOptimizationDoggettx(),
+    ])


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -299,7 +428,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
-        return efficient_dot_product_attention(
+        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
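
The is_available override on SdOptimizationSdpNoMem (inherited by SdOptimizationSdp) replaces the inline can_use_sdp check deleted from apply_optimizations. The same test works standalone:

# Mirrors SdOptimizationSdpNoMem.is_available(): the sdp optimizations are
# only offered on PyTorch 2.x, which provides scaled_dot_product_attention.
import torch

can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)
print("sdp available:", can_use_sdp)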
+1 −0
@@ -417,6 +417,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
}))

options_templates.update(options_section(('optimizations', "Optimizations"), {
    "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
    "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
    "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
    "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),