Commit 3d341ebc authored by Kohaku-Blueleaf

Merge branch 'dev' into test-fp8

parents 40ac134c e4410326
+1 −1
@@ -20,7 +20,7 @@ jobs:
          #     not to have GHA download an (at the time of writing) 4 GB cache
          #     of PyTorch and other dependencies.
      - name: Install Ruff
-       run: pip install ruff==0.0.272
+       run: pip install ruff==0.1.6
      - name: Run Ruff
        run: ruff .
  lint-js:
+4 −1
@@ -121,7 +121,9 @@ Alternatively, use online services (like Google Colab):
# Debian-based:
sudo apt install wget git python3 python3-venv libgl1 libglib2.0-0
# Red Hat-based:
-sudo dnf install wget git python3
+sudo dnf install wget git python3 gperftools-libs libglvnd-glx
+# openSUSE-based:
+sudo zypper install wget git python3 libtcmalloc4 libglvnd
# Arch-based:
sudo pacman -S wget git python3
```
@@ -174,5 +176,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
+- Hypertile - tfernd - https://github.com/tfernd/HyperTile
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
+345 −0
"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warning: the patch works well only if the input image's width and height are multiples of 128
Original author: @tfernd Github: https://github.com/tfernd/HyperTile
"""

from __future__ import annotations

import math
import random
from dataclasses import dataclass
from functools import cache, wraps
from typing import Callable

import torch.nn as nn
from einops import rearrange


@dataclass
class HypertileParams:
    depth: int = 0
    layer_name: str = ""
    tile_size: int = 0
    swap_size: int = 0
    aspect_ratio: float = 1.0
    forward: Callable | None = None
    enabled: bool = False



# TODO add SD-XL layers
DEPTH_LAYERS = {
    0: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.1.1.transformer_blocks.0.attn1",
        "input_blocks.2.1.transformer_blocks.0.attn1",
        "output_blocks.9.1.transformer_blocks.0.attn1",
        "output_blocks.10.1.transformer_blocks.0.attn1",
        "output_blocks.11.1.transformer_blocks.0.attn1",
        # SD 1.5 VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.6.1.transformer_blocks.0.attn1",
        "output_blocks.7.1.transformer_blocks.0.attn1",
        "output_blocks.8.1.transformer_blocks.0.attn1",
    ],
    2: [
        # SD 1.5 U-Net (diffusers)
        "down_blocks.2.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.2.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.1.attentions.2.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
    ],
    3: [
        # SD 1.5 U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SD 1.5 U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
    ],
}
# SD-XL layers, thanks to GitHub user @gel-crabs for the help
DEPTH_LAYERS_XL = {
    0: [
        # SDXL U-Net (diffusers)
        "down_blocks.0.attentions.0.transformer_blocks.0.attn1",
        "down_blocks.0.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.0.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.1.transformer_blocks.0.attn1",
        "up_blocks.3.attentions.2.transformer_blocks.0.attn1",
        # SDXL U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.0.attn1",
        "input_blocks.5.1.transformer_blocks.0.attn1",
        "output_blocks.3.1.transformer_blocks.0.attn1",
        "output_blocks.4.1.transformer_blocks.0.attn1",
        "output_blocks.5.1.transformer_blocks.0.attn1",
        # SDXL VAE
        "decoder.mid_block.attentions.0",
        "decoder.mid.attn_1",
    ],
    1: [
        # SDXL U-Net (diffusers)
        #"down_blocks.1.attentions.0.transformer_blocks.0.attn1",
        #"down_blocks.1.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.0.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.1.transformer_blocks.0.attn1",
        #"up_blocks.2.attentions.2.transformer_blocks.0.attn1",
        # SDXL U-Net (ldm)
        "input_blocks.4.1.transformer_blocks.1.attn1",
        "input_blocks.5.1.transformer_blocks.1.attn1",
        "output_blocks.3.1.transformer_blocks.1.attn1",
        "output_blocks.4.1.transformer_blocks.1.attn1",
        "output_blocks.5.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.0.attn1",
        "input_blocks.8.1.transformer_blocks.0.attn1",
        "output_blocks.0.1.transformer_blocks.0.attn1",
        "output_blocks.1.1.transformer_blocks.0.attn1",
        "output_blocks.2.1.transformer_blocks.0.attn1",
        "input_blocks.7.1.transformer_blocks.1.attn1",
        "input_blocks.8.1.transformer_blocks.1.attn1",
        "output_blocks.0.1.transformer_blocks.1.attn1",
        "output_blocks.1.1.transformer_blocks.1.attn1",
        "output_blocks.2.1.transformer_blocks.1.attn1",
        "input_blocks.7.1.transformer_blocks.2.attn1",
        "input_blocks.8.1.transformer_blocks.2.attn1",
        "output_blocks.0.1.transformer_blocks.2.attn1",
        "output_blocks.1.1.transformer_blocks.2.attn1",
        "output_blocks.2.1.transformer_blocks.2.attn1",
        "input_blocks.7.1.transformer_blocks.3.attn1",
        "input_blocks.8.1.transformer_blocks.3.attn1",
        "output_blocks.0.1.transformer_blocks.3.attn1",
        "output_blocks.1.1.transformer_blocks.3.attn1",
        "output_blocks.2.1.transformer_blocks.3.attn1",
        "input_blocks.7.1.transformer_blocks.4.attn1",
        "input_blocks.8.1.transformer_blocks.4.attn1",
        "output_blocks.0.1.transformer_blocks.4.attn1",
        "output_blocks.1.1.transformer_blocks.4.attn1",
        "output_blocks.2.1.transformer_blocks.4.attn1",
        "input_blocks.7.1.transformer_blocks.5.attn1",
        "input_blocks.8.1.transformer_blocks.5.attn1",
        "output_blocks.0.1.transformer_blocks.5.attn1",
        "output_blocks.1.1.transformer_blocks.5.attn1",
        "output_blocks.2.1.transformer_blocks.5.attn1",
        "input_blocks.7.1.transformer_blocks.6.attn1",
        "input_blocks.8.1.transformer_blocks.6.attn1",
        "output_blocks.0.1.transformer_blocks.6.attn1",
        "output_blocks.1.1.transformer_blocks.6.attn1",
        "output_blocks.2.1.transformer_blocks.6.attn1",
        "input_blocks.7.1.transformer_blocks.7.attn1",
        "input_blocks.8.1.transformer_blocks.7.attn1",
        "output_blocks.0.1.transformer_blocks.7.attn1",
        "output_blocks.1.1.transformer_blocks.7.attn1",
        "output_blocks.2.1.transformer_blocks.7.attn1",
        "input_blocks.7.1.transformer_blocks.8.attn1",
        "input_blocks.8.1.transformer_blocks.8.attn1",
        "output_blocks.0.1.transformer_blocks.8.attn1",
        "output_blocks.1.1.transformer_blocks.8.attn1",
        "output_blocks.2.1.transformer_blocks.8.attn1",
        "input_blocks.7.1.transformer_blocks.9.attn1",
        "input_blocks.8.1.transformer_blocks.9.attn1",
        "output_blocks.0.1.transformer_blocks.9.attn1",
        "output_blocks.1.1.transformer_blocks.9.attn1",
        "output_blocks.2.1.transformer_blocks.9.attn1",
    ],
    2: [
        # SDXL U-Net (diffusers)
        "mid_block.attentions.0.transformer_blocks.0.attn1",
        # SDXL U-Net (ldm)
        "middle_block.1.transformer_blocks.0.attn1",
        "middle_block.1.transformer_blocks.1.attn1",
        "middle_block.1.transformer_blocks.2.attn1",
        "middle_block.1.transformer_blocks.3.attn1",
        "middle_block.1.transformer_blocks.4.attn1",
        "middle_block.1.transformer_blocks.5.attn1",
        "middle_block.1.transformer_blocks.6.attn1",
        "middle_block.1.transformer_blocks.7.attn1",
        "middle_block.1.transformer_blocks.8.attn1",
        "middle_block.1.transformer_blocks.9.attn1",
    ],
    3: [],  # TODO - separate layers for SD-XL
}


RNG_INSTANCE = random.Random()


def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """
    Returns a random divisor of value that
        x * min_value <= value
    if max_options is 1, the behavior is deterministic
    """
    min_value = min(min_value, value)

    # All divisors of value that are >= min_value, in ascending order
    divisors = [i for i in range(min_value, value + 1) if value % i == 0]

    ns = [value // i for i in divisors[:max_options]]  # quotients in descending order; has at least 1 element

    idx = RNG_INSTANCE.randint(0, len(ns) - 1)

    return ns[idx]


def set_hypertile_seed(seed: int) -> None:
    RNG_INSTANCE.seed(seed)


@cache
def largest_tile_size_available(width: int, height: int) -> int:
    """
    Calculates the largest tile size available for a given width and height
    Tile size is always a power of 2
    """
    gcd = math.gcd(width, height)
    largest_tile_size_available = 1
    while gcd % (largest_tile_size_available * 2) == 0:
        largest_tile_size_available *= 2
    return largest_tile_size_available


def iterative_closest_divisors(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h * w = hw and w / h is as close as possible to aspect_ratio.
    We check all divisor pairs of hw and return the pair with the closest ratio.
    """
    divisors = [i for i in range(2, hw + 1) if hw % i == 0] # all divisors of hw
    pairs = [(i, hw // i) for i in divisors] # all pairs of divisors of hw
    ratios = [w/h for h, w in pairs] # all ratios of pairs of divisors of hw
    closest_ratio = min(ratios, key=lambda x: abs(x - aspect_ratio)) # closest ratio to aspect_ratio
    closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
    return closest_pair


@cache
def find_hw_candidates(hw: int, aspect_ratio: float) -> tuple[int, int]:
    """
    Finds h and w such that h * w = hw and w / h is close to aspect_ratio (width / height)
    """
    h, w = round(math.sqrt(hw / aspect_ratio)), round(math.sqrt(hw * aspect_ratio))
    # if rounding broke the factorization, try to repair it, else fall back to an exhaustive search
    if h * w != hw:
        w_candidate = hw / h
        # check if w is an integer
        if not w_candidate.is_integer():
            h_candidate = hw / w
            # check if h is an integer
            if not h_candidate.is_integer():
                return iterative_closest_divisors(hw, aspect_ratio)
            else:
                h = int(h_candidate)
        else:
            w = int(w_candidate)
    return h, w


def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:

    @wraps(params.forward)
    def wrapper(*args, **kwargs):
        if not params.enabled:
            return params.forward(*args, **kwargs)

        latent_tile_size = max(128, params.tile_size) // 8  # image-space tile size -> latent-space (the VAE downscales by 8)
        x = args[0]

        # VAE: the input is a 4D feature map (b, c, h, w)
        if x.ndim == 4:
            b, c, h, w = x.shape

            nh = random_divisor(h, latent_tile_size, params.swap_size)
            nw = random_divisor(w, latent_tile_size, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)  # split into nh * nw tiles

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)

        # U-Net: the input is a flattened token sequence (b, h*w, c)
        else:
            hw: int = x.size(1)
            h, w = find_hw_candidates(hw, params.aspect_ratio)
            assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"

            factor = 2 ** params.depth if scale_depth else 1
            nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
            nw = random_divisor(w, latent_tile_size * factor, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
                out = rearrange(out, "b nh nw (h w) c -> b (nh h nw w) c", h=h // nh, w=w // nw)

        return out

    return wrapper


def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
    hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
    if hypertile_layers is None:
        if not enable:
            return

        hypertile_layers = {}
        layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS

        for depth in range(4):
            for layer_name, module in model.named_modules():
                if any(layer_name.endswith(try_name) for try_name in layers[depth]):
                    params = HypertileParams()
                    module.__webui_hypertile_params = params
                    params.forward = module.forward
                    params.depth = depth
                    params.layer_name = layer_name
                    module.forward = self_attn_forward(params)

                    hypertile_layers[layer_name] = 1

        model.__webui_hypertile_layers = hypertile_layers

    aspect_ratio = width / height
    tile_size = min(largest_tile_size_available(width, height), tile_size_max)

    for layer_name, module in model.named_modules():
        if layer_name in hypertile_layers:
            params = module.__webui_hypertile_params

            params.tile_size = tile_size
            params.swap_size = swap_size
            params.aspect_ratio = aspect_ratio
            params.enabled = enable and params.depth <= max_depth
+73 −0
import hypertile
from modules import scripts, script_callbacks, shared


class ScriptHypertile(scripts.Script):
    name = "Hypertile"

    def title(self):
        return self.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def process(self, p, *args):
        hypertile.set_hypertile_seed(p.all_seeds[0])

        configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)

    def before_hr(self, p, *args):
        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)


def configure_hypertile(width, height, enable_unet=True):
    hypertile.hypertile_hook_model(
        shared.sd_model.first_stage_model,
        width,
        height,
        swap_size=shared.opts.hypertile_swap_size_vae,
        max_depth=shared.opts.hypertile_max_depth_vae,
        tile_size_max=shared.opts.hypertile_max_tile_vae,
        enable=shared.opts.hypertile_enable_vae,
    )

    hypertile.hypertile_hook_model(
        shared.sd_model.model,
        width,
        height,
        swap_size=shared.opts.hypertile_swap_size_unet,
        max_depth=shared.opts.hypertile_max_depth_unet,
        tile_size_max=shared.opts.hypertile_max_tile_unet,
        enable=enable_unet,
        is_sdxl=shared.sd_model.is_sdxl
    )


def on_ui_settings():
    import gradio as gr

    options = {
        "hypertile_explanation": shared.OptionHTML("""
    <a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layer within U-Net and VAE models,
    speeding up generation by a factor of 1x to 4x. The larger the generated image is, the greater the
    benefit.
    """),

        "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
        "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
        "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
        "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
        "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),

        "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
        "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
        "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
        "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
    }

    for name, opt in options.items():
        opt.section = ('hypertile', "Hypertile")
        shared.opts.add_option(name, opt)


script_callbacks.on_ui_settings(on_ui_settings)
+16 −2
@@ -6,6 +6,21 @@ import traceback
exception_records = []


+def format_traceback(tb):
+    return [[f"{x.filename}, line {x.lineno}, {x.name}", x.line] for x in traceback.extract_tb(tb)]
+
+
+def format_exception(e, tb):
+    return {"exception": str(e), "traceback": format_traceback(tb)}
+
+
+def get_exceptions():
+    try:
+        return list(reversed(exception_records))
+    except Exception as e:
+        return str(e)
+
+
def record_exception():
    _, e, tb = sys.exc_info()
    if e is None:
@@ -14,8 +29,7 @@ def record_exception():
    if exception_records and exception_records[-1] == e:
        return

-    from modules import sysinfo
-    exception_records.append(sysinfo.format_exception(e, tb))
+    exception_records.append(format_exception(e, tb))

    if len(exception_records) > 5:
        exception_records.pop(0)