Commit e3b53fd2 authored by brkirch

Add UI setting for upcasting attention to float32

Adds "Upcast cross attention layer to float32" option in Stable Diffusion settings. This allows for generating images using SD 2.1 models without --no-half or xFormers.

In order to make upcasting possible in the cross attention layer optimizations, it is necessary to indent several sections of code in sd_hijack_optimizations.py so that a context manager can be used to disable autocast. Also, even though Stable Diffusion (and Diffusers) only upcast q and k, unfortunately my findings were that most of the cross attention layer optimizations could not function unless v is upcast as well.
parent 84d9ce30
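Roughly, each optimized attention path below follows the same pattern: cast q/k/v to float32 when the option is on, run the attention math with autocast disabled so the inputs are not silently downcast again, and cast the result back to the model's working dtype. A minimal standalone sketch of that pattern, assuming a generic einsum-style attention (the function and argument names here are illustrative rather than the repo's API, apart from without_autocast, which matches the helper added to devices.py):

    import contextlib
    import torch
    from torch import einsum

    def without_autocast(disable=False):
        # Run the wrapped block with autocast disabled so float32 inputs are not
        # downcast again by autocast-managed ops.
        return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()

    def upcast_attention(q, k, v, scale, upcast_attn=True):
        dtype = q.dtype
        if upcast_attn:
            # Stable Diffusion/Diffusers only upcast q and k; here v is upcast too,
            # since most of the optimized paths could not run otherwise.
            q, k, v = q.float(), k.float(), v.float()

        with without_autocast(disable=not upcast_attn):
            sim = einsum('b i d, b j d -> b i j', q, k) * scale  # attention scores in fp32
            attn = sim.softmax(dim=-1)
            out = einsum('b i j, b j d -> b i d', attn, v)

        return out.to(dtype)  # back to the model's working dtype (e.g. fp16)

The diffs below apply this pattern to each existing optimization (split attention v1, split attention, InvokeAI, sub-quadratic, and xFormers) and add the corresponding "upcast_attn" option.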
+5 −1
@@ -108,6 +108,10 @@ def autocast(disable=False):
     return torch.autocast("cuda")
 
 
+def without_autocast(disable=False):
+    return torch.autocast("cuda", enabled=False) if torch.is_autocast_enabled() and not disable else contextlib.nullcontext()
+
+
 class NansException(Exception):
     pass
 
@@ -125,7 +129,7 @@ def test_for_nans(x, where):
         message = "A tensor with all NaNs was produced in Unet."
 
         if not shared.cmd_opts.no_half:
-            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try using --no-half commandline argument to fix this."
+            message += " This could be either because there's not enough precision to represent the picture, or because your video card does not support half type. Try setting the \"Upcast cross attention layer to float32\" option in Settings > Stable Diffusion or using the --no-half commandline argument to fix this."
 
     elif where == "vae":
         message = "A tensor with all NaNs was produced in VAE."
+1 −1
@@ -611,7 +611,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
             if p.n_iter > 1:
                 shared.state.job = f"Batch {n+1} out of {p.n_iter}"
 
-            with devices.autocast(disable=devices.unet_needs_upcast):
+            with devices.without_autocast() if devices.unet_needs_upcast else devices.autocast():
                 samples_ddim = p.sample(conditioning=c, unconditional_conditioning=uc, seeds=seeds, subseeds=subseeds, subseed_strength=p.subseed_strength, prompts=prompts)
 
             x_samples_ddim = [decode_first_stage(p.sd_model, samples_ddim[i:i+1].to(dtype=devices.dtype_vae))[0].cpu() for i in range(samples_ddim.size(0))]
+99 −60
@@ -9,7 +9,7 @@ from torch import einsum
 from ldm.util import default
 from einops import rearrange
 
-from modules import shared, errors
+from modules import shared, errors, devices
 from modules.hypernetworks import hypernetwork
 
 from .sub_quadratic_attention import efficient_dot_product_attention
@@ -52,7 +52,12 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
-    r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device)
-    for i in range(0, q.shape[0], 2):
-        end = i + 2
-        s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v.float()
+
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        r1 = torch.zeros(q.shape[0], q.shape[1], v.shape[2], device=q.device, dtype=q.dtype)
+        for i in range(0, q.shape[0], 2):
+            end = i + 2
+            s1 = einsum('b i d, b j d -> b i j', q[i:end], k[i:end])
@@ -65,6 +70,8 @@ def split_cross_attention_forward_v1(self, x, context=None, mask=None):
-        del s2
-    del q, k, v
+            del s2
+        del q, k, v
 
+    r1 = r1.to(dtype)
+
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1
 
@@ -82,7 +89,12 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
     k_in = self.to_k(context_k)
     v_in = self.to_v(context_v)
 
-    k_in *= self.scale
-
-    del context, x
-
+    dtype = q_in.dtype
+    if shared.opts.upcast_attn:
+        q_in, k_in, v_in = q_in.float(), k_in.float(), v_in if v_in.device.type == 'mps' else v_in.float()
+
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        k_in = k_in * self.scale
+
+        del context, x
+
@@ -122,6 +134,8 @@ def split_cross_attention_forward(self, x, context=None, mask=None):
 
-    del q, k, v
+        del q, k, v
 
+    r1 = r1.to(dtype)
+
     r2 = rearrange(r1, '(b h) n d -> b n (h d)', h=h)
     del r1
 
@@ -204,12 +218,20 @@ def split_cross_attention_forward_invokeAI(self, x, context=None, mask=None):
     context = default(context, x)
 
     context_k, context_v = hypernetwork.apply_hypernetworks(shared.loaded_hypernetworks, context)
-    k = self.to_k(context_k) * self.scale
+    k = self.to_k(context_k)
     v = self.to_v(context_v)
     del context, context_k, context_v, x
 
-    q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
-    r = einsum_op(q, k, v)
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k, v = q.float(), k.float(), v if v.device.type == 'mps' else v.float()
+
+    with devices.without_autocast(disable=not shared.opts.upcast_attn):
+        k = k * self.scale
+
+        q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> (b h) n d', h=h), (q, k, v))
+        r = einsum_op(q, k, v)
+    r = r.to(dtype)
     return self.to_out(rearrange(r, '(b h) n d -> b n (h d)', h=h))
 
 # -- End of code from https://github.com/invoke-ai/InvokeAI --
@@ -234,8 +256,14 @@ def sub_quad_attention_forward(self, x, context=None, mask=None):
     k = k.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
     v = v.unflatten(-1, (h, -1)).transpose(1,2).flatten(end_dim=1)
 
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k = q.float(), k.float()
+
     x = sub_quad_attention(q, k, v, q_chunk_size=shared.cmd_opts.sub_quad_q_chunk_size, kv_chunk_size=shared.cmd_opts.sub_quad_kv_chunk_size, chunk_threshold=shared.cmd_opts.sub_quad_chunk_threshold, use_checkpoint=self.training)
 
+    x = x.to(dtype)
+
     x = x.unflatten(0, (-1, h)).transpose(1,2).flatten(start_dim=2)
 
     out_proj, dropout = self.to_out
@@ -268,6 +296,7 @@ def sub_quad_attention(q, k, v, q_chunk_size=1024, kv_chunk_size=None, kv_chunk_
         query_chunk_size = q_tokens
         kv_chunk_size = k_tokens
 
-    return efficient_dot_product_attention(
-        q,
-        k,
+    with devices.without_autocast(disable=q.dtype == v.dtype):
+        return efficient_dot_product_attention(
+            q,
+            k,
@@ -306,8 +335,14 @@ def xformers_attention_forward(self, x, context=None, mask=None):
     q, k, v = map(lambda t: rearrange(t, 'b n (h d) -> b n h d', h=h), (q_in, k_in, v_in))
     del q_in, k_in, v_in
 
+    dtype = q.dtype
+    if shared.opts.upcast_attn:
+        q, k = q.float(), k.float()
+
     out = xformers.ops.memory_efficient_attention(q, k, v, attn_bias=None, op=get_xformers_flash_attention_op(q, k, v))
 
+    out = out.to(dtype)
+
     out = rearrange(out, 'b n h d -> b n (h d)', h=h)
     return self.to_out(out)
 
@@ -378,10 +413,14 @@ def xformers_attnblock_forward(self, x):
         v = self.v(h_)
         b, c, h, w = q.shape
         q, k, v = map(lambda t: rearrange(t, 'b c h w -> b (h w) c'), (q, k, v))
+        dtype = q.dtype
+        if shared.opts.upcast_attn:
+            q, k = q.float(), k.float()
         q = q.contiguous()
         k = k.contiguous()
         v = v.contiguous()
         out = xformers.ops.memory_efficient_attention(q, k, v, op=get_xformers_flash_attention_op(q, k, v))
+        out = out.to(dtype)
         out = rearrange(out, 'b (h w) c -> b c h w', h=h)
         out = self.proj_out(out)
         return x + out
+1 −0
@@ -410,6 +410,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
     "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
     "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
     "extra_networks_default_multiplier": OptionInfo(1.0, "Multiplier for extra networks", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.01}),
+    "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
 }))
 
 options_templates.update(options_section(('compatibility', "Compatibility"), {
+2 −2
@@ -67,7 +67,7 @@ def _summarize_chunk(
     max_score, _ = torch.max(attn_weights, -1, keepdim=True)
     max_score = max_score.detach()
     exp_weights = torch.exp(attn_weights - max_score)
-    exp_values = torch.bmm(exp_weights, value)
+    exp_values = torch.bmm(exp_weights, value) if query.device.type == 'mps' else torch.bmm(exp_weights, value.to(exp_weights.dtype)).to(value.dtype)
     max_score = max_score.squeeze(-1)
     return AttnChunk(exp_values, exp_weights.sum(dim=-1), max_score)
 
@@ -129,7 +129,7 @@ def _get_attention_scores_no_kv_chunking(
     )
     attn_probs = attn_scores.softmax(dim=-1)
     del attn_scores
-    hidden_states_slice = torch.bmm(attn_probs, value)
+    hidden_states_slice = torch.bmm(attn_probs, value) if query.device.type == 'mps' else torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
     return hidden_states_slice
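
As an aside on the sub_quadratic_attention change above: torch.bmm needs both operands in the same dtype, so once the attention weights are upcast, the value tensor has to be cast up for the matmul and the result cast back down (except on mps, where the diff keeps the plain bmm). A small standalone illustration of that pattern, not code from the repo:

    import torch

    attn_probs = torch.rand(2, 4, 8, dtype=torch.float32)  # upcast attention weights
    value = torch.rand(2, 8, 16, dtype=torch.float16)       # value projections still fp16

    # torch.bmm(attn_probs, value) would fail with a dtype mismatch here, so cast
    # value up for the matmul and cast the result back to its original dtype.
    out = torch.bmm(attn_probs, value.to(attn_probs.dtype)).to(value.dtype)
    print(out.dtype)  # torch.float16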