Commit 339b5315 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

custom unet support

parent a6e653be
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -13,7 +13,7 @@ from skimage import exposure
from typing import Any, Dict, List

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, extra_networks, sd_vae_approx, scripts, sd_samplers_common, sd_unet
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -674,6 +674,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:
            if shared.opts.live_previews_enable and opts.show_progress_type == "Approx NN":
                sd_vae_approx.model()

            sd_unet.apply_unet()

        if state.job_count == -1:
            state.job_count = p.n_iter

+20 −0
Original line number Diff line number Diff line
@@ -111,6 +111,7 @@ callback_map = dict(
    callbacks_before_ui=[],
    callbacks_on_reload=[],
    callbacks_list_optimizers=[],
    callbacks_list_unets=[],
)


@@ -271,6 +272,18 @@ def list_optimizers_callback():
    return res


def list_unets_callback():
    """Invoke every registered list_unets callback and return the collected options.

    Each callback receives the shared list and appends its SdUnetOption objects;
    a callback that raises is reported and skipped without affecting the others.
    """
    collected = []

    for entry in callback_map['callbacks_list_unets']:
        try:
            entry.callback(collected)
        except Exception:
            report_exception(entry, 'list_unets')

    return collected


def add_callback(callbacks, fun):
    stack = [x for x in inspect.stack() if x.filename != __file__]
    filename = stack[0].filename if len(stack) > 0 else 'unknown file'
@@ -430,3 +443,10 @@ def on_list_optimizers(callback):
    to it."""

    add_callback(callback_map['callbacks_list_optimizers'], callback)


def on_list_unets(callback):
    """Register *callback* to run when the UI builds its list of alternative unets.

    The callback is invoked with a single list argument and should append
    objects of type modules.sd_unet.SdUnetOption to it.
    """
    add_callback(callback_map['callbacks_list_unets'], callback)
+14 −6
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from torch.nn.functional import silu
from types import MethodType

import modules.textual_inversion.textual_inversion
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors
from modules import devices, sd_hijack_optimizations, shared, script_callbacks, errors, sd_unet
from modules.hypernetworks import hypernetwork
from modules.shared import cmd_opts
from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet, sd_hijack_xlmr, xlmr
@@ -43,7 +43,7 @@ def list_optimizers():
    optimizers.extend(new_optimizers)


def apply_optimizations():
def apply_optimizations(option=None):
    global current_optimizer

    undo_optimizations()
@@ -60,7 +60,7 @@ def apply_optimizations():
        current_optimizer.undo()
        current_optimizer = None

    selection = shared.opts.cross_attention_optimization
    selection = option or shared.opts.cross_attention_optimization
    if selection == "Automatic" and len(optimizers) > 0:
        matching_optimizer = next(iter([x for x in optimizers if x.cmd_opt and getattr(shared.cmd_opts, x.cmd_opt, False)]), optimizers[0])
    else:
@@ -72,12 +72,13 @@ def apply_optimizations():
        matching_optimizer = optimizers[0]

    if matching_optimizer is not None:
        print(f"Applying optimization: {matching_optimizer.name}... ", end='')
        print(f"Applying attention optimization: {matching_optimizer.name}... ", end='')
        matching_optimizer.apply()
        print("done.")
        current_optimizer = matching_optimizer
        return current_optimizer.name
    else:
        print("Disabling attention optimization")
        return ''


@@ -155,9 +156,9 @@ class StableDiffusionModelHijack:
    def __init__(self):
        self.embedding_db.add_embedding_dir(cmd_opts.embeddings_dir)

    def apply_optimizations(self):
    def apply_optimizations(self, option=None):
        try:
            self.optimization_method = apply_optimizations()
            self.optimization_method = apply_optimizations(option)
        except Exception as e:
            errors.display(e, "applying cross attention optimization")
            undo_optimizations()
@@ -194,6 +195,11 @@ class StableDiffusionModelHijack:

        self.layers = flatten(m)

        if not hasattr(ldm.modules.diffusionmodules.openaimodel, 'copy_of_UNetModel_forward_for_webui'):
            ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui = ldm.modules.diffusionmodules.openaimodel.UNetModel.forward

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_unet.UNetModel_forward

    def undo_hijack(self, m):
        if type(m.cond_stage_model) == xlmr.BertSeriesModelWithTransformation:
            m.cond_stage_model = m.cond_stage_model.wrapped
@@ -215,6 +221,8 @@ class StableDiffusionModelHijack:
        self.layers = None
        self.clip = None

        ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui

    def apply_circular(self, enable):
        if self.circular_enabled == enable:
            return
+3 −1
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas

from ldm.util import instantiate_from_config

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet
from modules.sd_hijack_inpainting import do_inpainting_hijack
from modules.timer import Timer
import tomesd
@@ -532,6 +532,8 @@ def reload_model_weights(sd_model=None, info=None):
        if sd_model.sd_model_checkpoint == checkpoint_info.filename:
            return

        sd_unet.apply_unet("None")

        if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
            lowvram.send_everything_to_cpu()
        else:

modules/sd_unet.py

0 → 100644
+92 −0
Original line number Diff line number Diff line
import torch.nn
import ldm.modules.diffusionmodules.openaimodel

from modules import script_callbacks, shared, devices

unet_options = []
current_unet_option = None
current_unet = None


def list_unets():
    """Rebuild the module-level unet_options list from registered script callbacks."""
    fetched = script_callbacks.list_unets_callback()

    # replace contents in place so existing references to the list stay valid
    unet_options.clear()
    unet_options.extend(fetched)


def get_unet_option(option=None):
    """Resolve *option* (or the sd_unet setting when omitted) to an SdUnetOption.

    Returns None for the "None" selection. "Automatic" picks the first option
    whose model_name matches the loaded checkpoint's name, falling back to
    "None" when nothing matches.
    """
    selected = option or shared.opts.sd_unet

    if selected == "None":
        return None

    if selected == "Automatic":
        checkpoint_name = shared.sd_model.sd_checkpoint_info.model_name

        matches = [opt for opt in unet_options if opt.model_name == checkpoint_name]

        selected = matches[0].label if matches else "None"

    for opt in unet_options:
        if opt.label == selected:
            return opt

    return None


def apply_unet(option=None):
    """Switch the active custom unet to the one selected by *option*.

    Resolves *option* (or the sd_unet setting) via get_unet_option(); does
    nothing if the resolved option is already active. Deactivates any
    previously active custom unet first. When the resolved option is None,
    the built-in unet is used again and moved back to the main device
    (unless lowvram/medvram is enabled); otherwise the built-in unet is
    offloaded to CPU before the replacement is created and activated.
    """
    global current_unet_option
    global current_unet

    new_option = get_unet_option(option)
    if new_option == current_unet_option:
        return

    if current_unet is not None:
        # fixed typo in log message: was "Dectivating"
        print(f"Deactivating unet: {current_unet.option.label}")
        current_unet.deactivate()

    current_unet_option = new_option
    if current_unet_option is None:
        current_unet = None

        if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
            shared.sd_model.model.diffusion_model.to(devices.device)

        return

    # free GPU memory held by the built-in unet before activating the replacement
    shared.sd_model.model.diffusion_model.to(devices.cpu)
    devices.torch_gc()

    current_unet = current_unet_option.create_unet()
    current_unet.option = current_unet_option
    print(f"Activating unet: {current_unet.option.label}")
    current_unet.activate()


class SdUnetOption:
    """Describes one selectable alternative unet implementation."""

    # name of the related checkpoint; when the loaded checkpoint's name matches,
    # this option is selected automatically for the unet
    model_name = None

    # display name of the unet in the UI
    label = None

    def create_unet(self):
        """Build and return the SdUnet object used instead of the built-in unet when making pictures."""
        raise NotImplementedError()


class SdUnet(torch.nn.Module):
    """Base class for replacement unets; subclasses must implement forward()."""

    def forward(self, x, timesteps, context, *args, **kwargs):
        """Compute the unet output; must be overridden by the subclass."""
        raise NotImplementedError()

    def activate(self):
        """Hook called when this unet becomes active; default does nothing."""

    def deactivate(self):
        """Hook called when this unet is swapped out; default does nothing."""


def UNetModel_forward(self, x, timesteps=None, context=None, *args, **kwargs):
    """Replacement for UNetModel.forward installed by the hijack.

    Routes the call to the active custom unet when one is set; otherwise
    falls back to the saved copy of the original built-in forward.
    """
    active = current_unet
    if active is None:
        return ldm.modules.diffusionmodules.openaimodel.copy_of_UNetModel_forward_for_webui(self, x, timesteps, context, *args, **kwargs)

    return active.forward(x, timesteps, context, *args, **kwargs)
Loading