Commit 05e6fc9a authored by AUTOMATIC

Merge branch 'ui-selection-for-cross-attention-optimization' into dev

Parents: cc6c0fc7, 2140bd1c
+7 −7
@@ -53,16 +53,16 @@ parser.add_argument("--xformers", action='store_true', help="enable xformers for
parser.add_argument("--force-enable-xformers", action='store_true', help="enable xformers for cross attention layers regardless of whether the checking code thinks you can run it; do not make bug reports if this fails to work")
parser.add_argument("--xformers-flash-attention", action='store_true', help="enable xformers with Flash Attention to improve reproducibility (supported for SD2.x or variant only)")
parser.add_argument("--deepdanbooru", action='store_true', help="does not do anything")
parser.add_argument("--opt-split-attention", action='store_true', help="force-enables Doggettx's cross-attention layer optimization. By default, it's on for torch cuda.")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="enable memory efficient sub-quadratic cross-attention layer optimization")
parser.add_argument("--opt-split-attention", action='store_true', help="prefer Doggettx's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-sub-quad-attention", action='store_true', help="prefer memory efficient sub-quadratic cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--sub-quad-q-chunk-size", type=int, help="query chunk size for the sub-quadratic cross-attention layer optimization to use", default=1024)
parser.add_argument("--sub-quad-kv-chunk-size", type=int, help="kv chunk size for the sub-quadratic cross-attention layer optimization to use", default=None)
parser.add_argument("--sub-quad-chunk-threshold", type=int, help="the percentage of VRAM threshold for the sub-quadratic cross-attention layer optimization to use chunking", default=None)
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="force-enables InvokeAI's cross-attention layer optimization. By default, it's on when cuda is unavailable.")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="enable older version of split attention optimization that does not consume all the VRAM it can find")
parser.add_argument("--opt-sdp-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="enable scaled dot product cross-attention layer optimization without memory efficient attention, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="force-disables cross-attention layer optimization")
parser.add_argument("--opt-split-attention-invokeai", action='store_true', help="prefer InvokeAI's cross-attention layer optimization for automatic choice of optimization")
parser.add_argument("--opt-split-attention-v1", action='store_true', help="prefer older version of split attention optimization for automatic choice of optimization")
parser.add_argument("--opt-sdp-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization for automatic choice of optimization; requires PyTorch 2.*")
parser.add_argument("--opt-sdp-no-mem-attention", action='store_true', help="prefer scaled dot product cross-attention layer optimization without memory efficient attention for automatic choice of optimization, makes image generation deterministic; requires PyTorch 2.*")
parser.add_argument("--disable-opt-split-attention", action='store_true', help="does not do anything")
parser.add_argument("--disable-nan-check", action='store_true', help="do not check if produced images/latent spaces have nans; useful for running without a checkpoint in CI")
parser.add_argument("--use-cpu", nargs='+', help="use CPU as torch device for specified modules", default=[], type=str.lower)
parser.add_argument("--listen", action='store_true', help="launch gradio with 0.0.0.0 as server name, allowing to respond to network requests")
+21 −0
@@ -110,6 +110,7 @@ callback_map = dict(
    callbacks_script_unloaded=[],
    callbacks_before_ui=[],
    callbacks_on_reload=[],
    callbacks_list_optimizers=[],
)


@@ -258,6 +259,18 @@ def before_ui_callback():
            report_exception(c, 'before_ui')


def list_optimizers_callback():
    res = []

    for c in callback_map['callbacks_list_optimizers']:
        try:
            c.callback(res)
        except Exception:
            report_exception(c, 'list_optimizers')

    return res


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -409,3 +422,11 @@ def on_before_ui(callback):
    """register a function to be called before the UI is created."""

    add_callback(callback_map['callbacks_before_ui'], callback)


def on_list_optimizers(callback):
    """register a function to be called when UI is making a list of cross attention optimization options.
    The function will be called with one argument, a list, and shall add objects of type modules.sd_hijack_optimizations.SdOptimization
    to it."""

    add_callback(callback_map['callbacks_list_optimizers'], callback)
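
Note: the new on_list_optimizers hook is what lets extensions contribute their own cross-attention optimizations. A rough sketch of how an extension script might use it follows; the names MyOptimization, "my-optimization", and my_script.py are hypothetical and not part of this commit, and the apply() body simply reuses the existing V1 forward for illustration.

```python
# Hypothetical extension script (e.g. extensions/my-extension/scripts/my_script.py).
# Sketch only; shows the shape of the new hook, not committed code.
import ldm.modules.attention

from modules import script_callbacks, sd_hijack_optimizations


class MyOptimization(sd_hijack_optimizations.SdOptimization):
    name = "my-optimization"
    label = "example cross attention optimization"
    priority = 5  # lower than the built-ins, so "Automatic" rarely picks it

    def apply(self):
        # A real extension would install its own CrossAttention.forward here;
        # reusing the existing V1 implementation keeps the sketch self-contained.
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1


def add_my_optimizer(res):
    # Called with the shared list; append SdOptimization instances to it.
    res.append(MyOptimization())


script_callbacks.on_list_optimizers(add_my_optimizer)
```

Entries collected this way are filtered with is_available() and sorted by priority in list_optimizers() (see the sd_hijack.py changes below), and then show up as choices for the new cross attention optimization setting.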
+48 −41
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType

import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -28,57 +28,56 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
ldm.modules.attention.print = lambda *args: None
ldm.modules.diffusionmodules.model.print = lambda *args: None

optimizers = []
current_optimizer: sd_hijack_optimizations.SdOptimization = None


def list_optimizers():
    new_optimizers = script_callbacks.list_optimizers_callback()

    new_optimizers = [x for x in new_optimizers if x.is_available()]

    new_optimizers = sorted(new_optimizers, key=lambda x: x.priority, reverse=True)

    optimizers.clear()
    optimizers.extend(new_optimizers)


def apply_optimizations():
    global current_optimizer

    undo_optimizations()

    ldm.modules.diffusionmodules.model.nonlinearity = silu
    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th

    optimization_method = None
    if current_optimizer is not None:
        current_optimizer.undo()
        current_optimizer = None

    can_use_sdp = hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention) # not everyone has torch 2.x to use sdp

    if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
        print("Applying xformers cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.xformers_attnblock_forward
        optimization_method = 'xformers'
    elif cmd_opts.opt_sdp_no_mem_attention and can_use_sdp:
        print("Applying scaled dot product cross attention optimization (without memory efficient attention).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_no_mem_attnblock_forward
        optimization_method = 'sdp-no-mem'
    elif cmd_opts.opt_sdp_attention and can_use_sdp:
        print("Applying scaled dot product cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sdp_attnblock_forward
        optimization_method = 'sdp'
    elif cmd_opts.opt_sub_quad_attention:
        print("Applying sub-quadratic cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.sub_quad_attnblock_forward
        optimization_method = 'sub-quadratic'
    elif cmd_opts.opt_split_attention_v1:
        print("Applying v1 cross attention optimization.")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_v1
        optimization_method = 'V1'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention_invokeai or not cmd_opts.opt_split_attention and not torch.cuda.is_available()):
        print("Applying cross attention optimization (InvokeAI).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward_invokeAI
        optimization_method = 'InvokeAI'
    elif not cmd_opts.disable_opt_split_attention and (cmd_opts.opt_split_attention or torch.cuda.is_available()):
        print("Applying cross attention optimization (Doggettx).")
        ldm.modules.attention.CrossAttention.forward = sd_hijack_optimizations.split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sd_hijack_optimizations.cross_attention_attnblock_forward
        optimization_method = 'Doggettx'

    return optimization_method
    selection = shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
        matching_optimizer = next(iter([x for x in optimizers if x.title() == selection]), None)

    if selection == "None":
        matching_optimizer = None
    elif matching_optimizer is None:
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying optimization: {matching_optimizer.name}")
        matching_optimizer.apply()
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        return ''


def undo_optimizations():
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.nonlinearity = diffusionmodules_model_nonlinearity
    ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
    ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


@@ -169,7 +168,11 @@ class StableDiffusionModelHijack:
        if m.cond_stage_key == "edit":
            sd_hijack_unet.hijack_ddpm_edit()

        try:
            self.optimization_method = apply_optimizations()
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()

        self.clip = m.cond_stage_model

@@ -223,6 +226,10 @@ class StableDiffusionModelHijack:

        return token_count, self.clip.get_target_prompt_token_count(token_count)

    def redo_hijack(self, m):
        self.undo_hijack(m)
        self.hijack(m)


class EmbeddingsWithFixes(torch.nn.Module):
    def __init__(self, wrapped, embeddings):
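
Note: redo_hijack() exists so the chosen optimization can be swapped at runtime instead of only at startup. The wiring that triggers it is not part of this diff; presumably the new setting is hooked up with shared.opts.onchange, roughly along these lines (an assumption, not the committed call site):

```python
# Sketch (assumption): re-apply the hijack when the user changes the new setting.
from modules import shared, sd_hijack

shared.opts.onchange(
    "cross_attention_optimization",
    lambda: sd_hijack.model_hijack.redo_hijack(shared.sd_model),
    call=False,  # register the handler without running it immediately
)
```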
+122 −3
@@ -9,10 +9,129 @@ from torch import einsum
from ldm.util import default
from einops import rearrange

from modules import shared, errors, devices
from modules import shared, errors, devices, sub_quadratic_attention
from modules.hypernetworks import hypernetwork

from .sub_quadratic_attention import efficient_dot_product_attention
import ldm.modules.attention
import ldm.modules.diffusionmodules.model

diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward


class SdOptimization:
    name: str = None
    label: str | None = None
    cmd_opt: str | None = None
    priority: int = 0

    def title(self):
        if self.label is None:
            return self.name

        return f"{self.name} - {self.label}"

    def is_available(self):
        return True

    def apply(self):
        pass

    def undo(self):
        ldm.modules.attention.CrossAttention.forward = hypernetwork.attention_CrossAttention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = diffusionmodules_model_AttnBlock_forward


class SdOptimizationXformers(SdOptimization):
    name = "xformers"
    cmd_opt = "xformers"
    priority = 100

    def is_available(self):
        return shared.cmd_opts.force_enable_xformers or (shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0))

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = xformers_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = xformers_attnblock_forward


class SdOptimizationSdpNoMem(SdOptimization):
    name = "sdp-no-mem"
    label = "scaled dot product without memory efficient attention"
    cmd_opt = "opt_sdp_no_mem_attention"
    priority = 90

    def is_available(self):
        return hasattr(torch.nn.functional, "scaled_dot_product_attention") and callable(torch.nn.functional.scaled_dot_product_attention)

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_no_mem_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_no_mem_attnblock_forward


class SdOptimizationSdp(SdOptimizationSdpNoMem):
    name = "sdp"
    label = "scaled dot product"
    cmd_opt = "opt_sdp_attention"
    priority = 80

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = scaled_dot_product_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sdp_attnblock_forward


class SdOptimizationSubQuad(SdOptimization):
    name = "sub-quadratic"
    cmd_opt = "opt_sub_quad_attention"
    priority = 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = sub_quad_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = sub_quad_attnblock_forward


class SdOptimizationV1(SdOptimization):
    name = "V1"
    label = "original v1"
    cmd_opt = "opt_split_attention_v1"
    priority = 10


    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1


class SdOptimizationInvokeAI(SdOptimization):
    name = "InvokeAI"
    cmd_opt = "opt_split_attention_invokeai"

    @property
    def priority(self):
        return 1000 if not torch.cuda.is_available() else 10

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_invokeAI


class SdOptimizationDoggettx(SdOptimization):
    name = "Doggettx"
    cmd_opt = "opt_split_attention"
    priority = 20

    def apply(self):
        ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward
        ldm.modules.diffusionmodules.model.AttnBlock.forward = cross_attention_attnblock_forward


def list_optimizers(res):
    res.extend([
        SdOptimizationXformers(),
        SdOptimizationSdpNoMem(),
        SdOptimizationSdp(),
        SdOptimizationSubQuad(),
        SdOptimizationV1(),
        SdOptimizationInvokeAI(),
        SdOptimizationDoggettx(),
    ])


if shared.cmd_opts.xformers or shared.cmd_opts.force_enable_xformers:
@@ -299,7 +418,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
        kv_chunk_size = k_tokens

    with devices.without_autocast(disable=q.dtype == v.dtype):
        return efficient_dot_product_attention(
        return sub_quadratic_attention.efficient_dot_product_attention(
            q,
            k,
            v,
+1 −0
@@ -418,6 +418,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
}))

options_templates.update(options_section(('optimizations', "Optimizations"), {
    "cross_attention_optimization": OptionInfo("Automatic", "Cross attention optimization", gr.Dropdown, lambda: {"choices": shared_items.cross_attention_optimizations()}),
    "s_min_uncond": OptionInfo(0, "Negative Guidance minimum sigma", gr.Slider, {"minimum": 0.0, "maximum": 4.0, "step": 0.01}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9177").info("skip negative prompt for some steps when the image is almost ready; 0=disable, higher=faster"),
    "token_merging_ratio": OptionInfo(0.0, "Token merging ratio", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).link("PR", "https://github.com/AUTOMATIC1111/stable-diffusion-webui/pull/9256").info("0=disable, higher=faster"),
    "token_merging_ratio_img2img": OptionInfo(0.0, "Token merging ratio for img2img", gr.Slider, {"minimum": 0.0, "maximum": 0.9, "step": 0.1}).info("only applies if non-zero and overrides above"),