Commit f85b7476 authored by AUTOMATIC1111

Merge branch 'hypertile-in-sample' into dev

parents fd8674a4 d2e0c1ca
+1 −0
@@ -174,5 +174,6 @@ Licenses for borrowed code can be found in `Settings -> Licenses` screen, and al
- TAESD - Ollin Boer Bohan - https://github.com/madebyollin/taesd
- LyCORIS - KohakuBlueleaf
- Restart sampling - lambertae - https://github.com/Newbeeer/diffusion_restart_sampling
- Hypertile - tfernd - https://github.com/tfernd/HyperTile
- Initial Gradio script - posted on 4chan by an Anonymous user. Thank you Anonymous user.
- (You)
+99 −122
"""
Hypertile module for splitting attention layers in SD-1.5 U-Net and SD-1.5 VAE
Warning: The patch works well only if the input image has a width and height that are multiples of 128
Original author: @tfernd Github: https://github.com/tfernd/HyperTile
"""

from __future__ import annotations

import functools
from dataclasses import dataclass
from typing import Callable

@@ -18,6 +21,19 @@ import random

from einops import rearrange


@dataclass
class HypertileParams:
    depth = 0
    layer_name = ""
    tile_size: int = 0
    swap_size: int = 0
    aspect_ratio: float = 1.0
    forward = None
    enabled = False



# TODO add SD-XL layers
DEPTH_LAYERS = {
    0: [
@@ -176,6 +192,7 @@ DEPTH_LAYERS_XL = {

RNG_INSTANCE = random.Random()


def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:
    """
    Returns a random divisor of value that
@@ -193,9 +210,12 @@ def random_divisor(value: int, min_value: int, /, max_options: int = 1) -> int:

    return ns[idx]
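
# Example (illustrative): the divisors of 96 that are >= 16 are [16, 24, 32, 48, 96],
# so the candidate tile counts are 96 // d. With max_options=1 the result is
# deterministic:
#   random_divisor(96, 16)                 # always 96 // 16 == 6
#   random_divisor(96, 16, max_options=3)  # seeded choice among 6, 4, 3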


def set_hypertile_seed(seed: int) -> None:
    RNG_INSTANCE.seed(seed)


@functools.cache
def largest_tile_size_available(width: int, height: int) -> int:
    """
    Calculates the largest tile size available for a given width and height
@@ -207,6 +227,7 @@ def largest_tile_size_available(width:int, height:int) -> int:
        largest_tile_size_available *= 2
    return largest_tile_size_available
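
# Example (illustrative): the result is the largest power of 2 that divides
# gcd(width, height):
#   largest_tile_size_available(1024, 768)  # gcd is 256 -> returns 256
#   largest_tile_size_available(1152, 896)  # gcd is 128 -> returns 128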


def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
    """
    Finds h and w such that h*w = hw and h/w = aspect_ratio
@@ -219,6 +240,7 @@ def iterative_closest_divisors(hw:int, aspect_ratio:float) -> tuple[int, int]:
    closest_pair = pairs[ratios.index(closest_ratio)] # closest pair of divisors to aspect_ratio
    return closest_pair


@cache
def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
    """
@@ -240,44 +262,28 @@ def find_hw_candidates(hw:int, aspect_ratio:float) -> tuple[int, int]:
            w = int(w_candidate)
    return h, w
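
# Example (illustrative): a 1024x768 image gives 128x96 latents, so U-Net
# attention sees a flattened sequence of hw = 12288 tokens; given the aspect
# ratio 1024 / 768, find_hw_candidates recovers the (h, w) factorization of
# 12288 whose ratio is closest to that aspect ratio.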


def self_attn_forward(params: HypertileParams, scale_depth=True) -> Callable:
    @wraps(params.forward)
    def wrapper(*args, **kwargs):
        if not params.enabled:
            return params.forward(*args, **kwargs)

        latent_tile_size = max(128, params.tile_size) // 8
        x = args[0]

        # VAE
        if x.ndim == 4:
            b, c, h, w = x.shape

            nh = random_divisor(h, latent_tile_size, params.swap_size)
            nw = random_divisor(w, latent_tile_size, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b c (nh h) (nw w) -> (b nh nw) c h w", nh=nh, nw=nw)  # split into nh * nw tiles

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) c h w -> b c (nh h) (nw w)", nh=nh, nw=nw)
@@ -285,19 +291,17 @@ def split_attention(
        # U-Net
        else:
            hw: int = x.size(1)
            h, w = find_hw_candidates(hw, params.aspect_ratio)
            assert h * w == hw, f"Invalid aspect ratio {params.aspect_ratio} for input of shape {x.shape}, hw={hw}, h={h}, w={w}"

            factor = 2 ** params.depth if scale_depth else 1
            nh = random_divisor(h, latent_tile_size * factor, params.swap_size)
            nw = random_divisor(w, latent_tile_size * factor, params.swap_size)

            if nh * nw > 1:
                x = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c", h=h // nh, w=w // nw, nh=nh, nw=nw)

            out = params.forward(x, *args[1:], **kwargs)

            if nh * nw > 1:
                out = rearrange(out, "(b nh nw) hw c -> b nh nw hw c", nh=nh, nw=nw)
@@ -307,65 +311,38 @@ def split_attention(

    return wrapper
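
# Minimal round-trip sketch (illustrative; assumes torch is importable): the
# U-Net branch above folds nh * nw spatial tiles into the batch dimension, so
# attention costs O((hw / (nh * nw))**2) per tile instead of O(hw**2), and the
# inverse rearrange restores the original token layout exactly:
#
#   import torch
#   x = torch.randn(2, 96 * 128, 320)  # (batch, tokens, channels)
#   nh, nw = 2, 2
#   tiled = rearrange(x, "b (nh h nw w) c -> (b nh nw) (h w) c",
#                     h=96 // nh, w=128 // nw, nh=nh, nw=nw)
#   restored = rearrange(tiled, "(b nh nw) (h w) c -> b (nh h nw w) c",
#                        h=96 // nh, w=128 // nw, nh=nh, nw=nw)
#   assert torch.equal(x, restored)
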

def hypertile_hook_model(model: nn.Module, width, height, *, enable=False, tile_size_max=128, swap_size=1, max_depth=3, is_sdxl=False):
    hypertile_layers = getattr(model, "__webui_hypertile_layers", None)
    if hypertile_layers is None:
        if not enable:
            return

        hypertile_layers = {}
        layers = DEPTH_LAYERS_XL if is_sdxl else DEPTH_LAYERS

        for depth in range(4):
            for layer_name, module in model.named_modules():
                if any(layer_name.endswith(try_name) for try_name in layers[depth]):
                    params = HypertileParams()
                    module.__webui_hypertile_params = params
                    params.forward = module.forward
                    params.depth = depth
                    params.layer_name = layer_name
                    module.forward = self_attn_forward(params)

                    hypertile_layers[layer_name] = 1

        model.__webui_hypertile_layers = hypertile_layers

    aspect_ratio = width / height
    tile_size = min(largest_tile_size_available(width, height), tile_size_max)

    for layer_name, module in model.named_modules():
        if layer_name in hypertile_layers:
            params = module.__webui_hypertile_params

            params.tile_size = tile_size
            params.swap_size = swap_size
            params.aspect_ratio = aspect_ratio
            params.enabled = enable and params.depth <= max_depth
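
# Minimal usage sketch (illustrative; `unet` stands for any supported nn.Module):
# the first call installs the forward wrappers; later calls only refresh the
# per-layer HypertileParams, so re-hooking with new dimensions is cheap:
#
#   hypertile_hook_model(unet, 1024, 768, enable=True, tile_size_max=256, swap_size=3, max_depth=3)
#   ...  # sample as usual; hooked layers tile their self-attention
#   hypertile_hook_model(unet, 1024, 768, enable=False)  # turn off without unhooking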
+73 −0
import hypertile
from modules import scripts, script_callbacks, shared


class ScriptHypertile(scripts.Script):
    name = "Hypertile"

    def title(self):
        return self.name

    def show(self, is_img2img):
        return scripts.AlwaysVisible

    def process(self, p, *args):
        hypertile.set_hypertile_seed(p.all_seeds[0])

        configure_hypertile(p.width, p.height, enable_unet=shared.opts.hypertile_enable_unet)

    def before_hr(self, p, *args):
        configure_hypertile(p.hr_upscale_to_x, p.hr_upscale_to_y, enable_unet=shared.opts.hypertile_enable_unet_secondpass or shared.opts.hypertile_enable_unet)


def configure_hypertile(width, height, enable_unet=True):
    hypertile.hypertile_hook_model(
        shared.sd_model.first_stage_model,
        width,
        height,
        swap_size=shared.opts.hypertile_swap_size_vae,
        max_depth=shared.opts.hypertile_max_depth_vae,
        tile_size_max=shared.opts.hypertile_max_tile_vae,
        enable=shared.opts.hypertile_enable_vae,
    )

    hypertile.hypertile_hook_model(
        shared.sd_model.model,
        width,
        height,
        swap_size=shared.opts.hypertile_swap_size_unet,
        max_depth=shared.opts.hypertile_max_depth_unet,
        tile_size_max=shared.opts.hypertile_max_tile_unet,
        enable=enable_unet,
        is_sdxl=shared.sd_model.is_sdxl
    )
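
# Illustrative example (values hypothetical): for a 1024x768 run with a 1.5x
# hires fix, process() hooks with the base size and before_hr() re-hooks with
# the upscaled size:
#
#   configure_hypertile(1024, 768, enable_unet=shared.opts.hypertile_enable_unet)
#   configure_hypertile(1536, 1152, enable_unet=shared.opts.hypertile_enable_unet_secondpass
#                       or shared.opts.hypertile_enable_unet)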


def on_ui_settings():
    import gradio as gr

    options = {
        "hypertile_explanation": shared.OptionHTML("""
    <a href='https://github.com/tfernd/HyperTile'>Hypertile</a> optimizes the self-attention layers within U-Net and VAE models,
    speeding up their computation by a factor of between 1 and 4. The larger the generated image, the greater the
    benefit.
    """),

        "hypertile_enable_unet": shared.OptionInfo(False, "Enable Hypertile U-Net").info("noticeable change in details of the generated picture; if enabled, overrides the setting below"),
        "hypertile_enable_unet_secondpass": shared.OptionInfo(False, "Enable Hypertile U-Net for hires fix second pass"),
        "hypertile_max_depth_unet": shared.OptionInfo(3, "Hypertile U-Net max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
        "hypertile_max_tile_unet": shared.OptionInfo(256, "Hypertile U-net max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
        "hypertile_swap_size_unet": shared.OptionInfo(3, "Hypertile U-net swap size", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),

        "hypertile_enable_vae": shared.OptionInfo(False, "Enable Hypertile VAE").info("minimal change in the generated picture"),
        "hypertile_max_depth_vae": shared.OptionInfo(3, "Hypertile VAE max depth", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}),
        "hypertile_max_tile_vae": shared.OptionInfo(128, "Hypertile VAE max tile size", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}),
        "hypertile_swap_size_vae": shared.OptionInfo(3, "Hypertile VAE swap size ", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}),
    }

    for name, opt in options.items():
        opt.section = ('hypertile', "Hypertile")
        shared.opts.add_option(name, opt)


script_callbacks.on_ui_settings(on_ui_settings)
+13 −24
@@ -24,7 +24,6 @@ from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.paths as paths
import modules.face_restoration
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
@@ -861,8 +860,7 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
                p.comment(comment)

            p.extra_generation_params.update(model_hijack.extra_generation_params)
            # add batch size + hypertile status to information to reproduce the run

            if p.n_iter > 1:
                shared.state.job = f"Batch {n+1} out of {p.n_iter}"

@@ -874,7 +872,6 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            else:
                if opts.sd_vae_decode_method != 'Full':
                    p.extra_generation_params['VAE Decoder'] = opts.sd_vae_decode_method
                x_samples_ddim = decode_latent_batch(p.sd_model, samples_ddim, target_device=devices.cpu, check_for_nans=True)

            x_samples_ddim = torch.stack(x_samples_ddim).float()
@@ -1141,25 +1138,23 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

    def sample(self, conditioning, unconditional_conditioning, seeds, subseeds, subseed_strength, prompts):
        self.sampler = sd_samplers.create_sampler(self.sampler_name, self.sd_model)

        x = self.rng.next()
        samples = self.sampler.sample(self, x, conditioning, unconditional_conditioning, image_conditioning=self.txt2img_image_conditioning(x))
        del x

        if not self.enable_hr:
            return samples
        devices.torch_gc()

        if self.latent_scale_mode is None:
            decoded_samples = torch.stack(decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)).to(dtype=torch.float32)
        else:
            decoded_samples = None

        with sd_models.SkipWritingToConfig():
            sd_models.reload_model_weights(info=self.hr_checkpoint_info)

        return self.sample_hr_pass(samples, decoded_samples, seeds, subseeds, subseed_strength, prompts)

    def sample_hr_pass(self, samples, decoded_samples, seeds, subseeds, subseed_strength, prompts):
@@ -1244,17 +1239,14 @@ class StableDiffusionProcessingTxt2Img(StableDiffusionProcessing):

        if self.scripts is not None:
            self.scripts.before_hr(self)

        samples = self.sampler.sample_img2img(self, samples, noise, self.hr_c, self.hr_uc, steps=self.hr_second_pass_steps or self.steps, image_conditioning=image_conditioning)

        sd_models.apply_token_merging(self.sd_model, self.get_token_merging_ratio())

        self.sampler = None
        devices.torch_gc()

        decoded_samples = decode_latent_batch(self.sd_model, samples, target_device=devices.cpu, check_for_nans=True)

        self.is_hr_pass = False
@@ -1532,10 +1524,7 @@ class StableDiffusionProcessingImg2Img(StableDiffusionProcessing):
        if self.initial_noise_multiplier != 1.0:
            self.extra_generation_params["Noise multiplier"] = self.initial_noise_multiplier
            x *= self.initial_noise_multiplier

        samples = self.sampler.sample_img2img(self, self.init_latent, x, conditioning, unconditional_conditioning, image_conditioning=self.image_conditioning)

        if self.mask is not None:
+0 −8
@@ -201,14 +201,6 @@ options_templates.update(options_section(('optimizations', "Optimizations"), {
    "pad_cond_uncond": OptionInfo(False, "Pad prompt/negative prompt to be same length", infotext='Pad conds').info("improves performance when prompt and negative prompt have different lengths; changes seeds"),
    "persistent_cond_cache": OptionInfo(True, "Persistent cond cache").info("do not recalculate conds from prompts if prompts have not changed since previous calculation"),
    "batch_cond_uncond": OptionInfo(True, "Batch cond/uncond").info("do both conditional and unconditional denoising in one batch; uses a bit more VRAM during sampling, but improves speed; previously this was controlled by --always-batch-cond-uncond comandline argument"),
    "hypertile_split_unet_attn" : OptionInfo(False, "Split attention in Unet with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"),
    "hypertile_split_vae_attn": OptionInfo(False, "Split attention in VAE with HyperTile").link("Github", "https://github.com/tfernd/HyperTile").info("improves performance; changes behavior, but deterministic"),
    "hypertile_max_depth_vae" : OptionInfo(3, "Max depth for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
    "hypertile_max_depth_unet" : OptionInfo(3, "Max depth for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 3, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
    "hypertile_max_tile_vae" : OptionInfo(128, "Max tile size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"),
    "hypertile_max_tile_unet" : OptionInfo(256, "Max tile size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 512, "step": 16}).link("Github", "https://github.com/tfernd/HyperTile"),
    "hypertile_swap_size_unet": OptionInfo(3, "Swap size for Unet HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
    "hypertile_swap_size_vae": OptionInfo(3, "Swap size for VAE HyperTile hijack", gr.Slider, {"minimum": 0, "maximum": 6, "step": 1}).link("Github", "https://github.com/tfernd/HyperTile"),
}))

options_templates.update(options_section(('compatibility', "Compatibility"), {