Commit 7dbfd8a7 authored by AUTOMATIC

do not replace entire unet for the resolution hack

parent 2641d1b8
modules/sd_hijack.py  +3 −2
@@ -11,7 +11,7 @@ import modules.textual_inversion.textual_inversion
 from modules import prompt_parser, devices, sd_hijack_optimizations, shared, sd_hijack_checkpoint
 from modules.hypernetworks import hypernetwork
 from modules.shared import opts, device, cmd_opts
-from modules import sd_hijack_clip, sd_hijack_open_clip
+from modules import sd_hijack_clip, sd_hijack_open_clip, sd_hijack_unet
 
 from modules.sd_hijack_optimizations import invokeAI_mps_available
 
@@ -35,11 +35,12 @@ ldm.modules.attention.BasicTransformerBlock.ATTENTION_MODES["softmax-xformers"]
 ldm.modules.attention.print = lambda *args: None
 ldm.modules.diffusionmodules.model.print = lambda *args: None
 
+
 def apply_optimizations():
     undo_optimizations()
 
     ldm.modules.diffusionmodules.model.nonlinearity = silu
-    ldm.modules.diffusionmodules.openaimodel.UNetModel.forward = sd_hijack_optimizations.patched_unet_forward
+    ldm.modules.diffusionmodules.openaimodel.th = sd_hijack_unet.th
 
     if cmd_opts.force_enable_xformers or (cmd_opts.xformers and shared.xformers_available and torch.version.cuda and (6, 0) <= torch.cuda.get_device_capability(shared.device) <= (9, 0)):
         print("Applying xformers cross attention optimization.")
modules/sd_hijack_optimizations.py  +0 −28
@@ -313,31 +313,3 @@ def xformers_attnblock_forward(self, x):
         return x + out
     except NotImplementedError:
         return cross_attention_attnblock_forward(self, x)
-
-def patched_unet_forward(self, x, timesteps=None, context=None, y=None,**kwargs):
-    assert (y is not None) == (
-        self.num_classes is not None
-    ), "must specify y if and only if the model is class-conditional"
-    hs = []
-    t_emb = timestep_embedding(timesteps, self.model_channels, repeat_only=False)
-    emb = self.time_embed(t_emb)
-
-    if self.num_classes is not None:
-        assert y.shape == (x.shape[0],)
-        emb = emb + self.label_emb(y)
-
-    h = x.type(self.dtype)
-    for module in self.input_blocks:
-        h = module(h, emb, context)
-        hs.append(h)
-    h = self.middle_block(h, emb, context)
-    for module in self.output_blocks:
-        if h.shape[-2:] != hs[-1].shape[-2:]:
-            h = F.interpolate(h, hs[-1].shape[-2:], mode="nearest")
-        h = torch.cat([h, hs.pop()], dim=1)
-        h = module(h, emb, context)
-    h = h.type(x.dtype)
-    if self.predict_codebook_ids:
-        return self.id_predictor(h)
-    else:
-        return self.out(h)
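
The function removed here was essentially a copy of the upstream UNetModel.forward, differing from it only in the F.interpolate guard before each skip concatenation; the new module below keeps the upstream forward untouched and reproduces that guard inside a cat wrapper instead.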
modules/sd_hijack_unet.py  +30 −0
@@ -0,0 +1,30 @@
+import torch
+
+
+class TorchHijackForUnet:
+    """
+    This is torch, but with cat that resizes tensors to appropriate dimensions if they do not match;
+    this makes it possible to create pictures with dimensions that are multiples of 8 rather than 64
+    """
+
+    def __getattr__(self, item):
+        if item == 'cat':
+            return self.cat
+
+        if hasattr(torch, item):
+            return getattr(torch, item)
+
+        raise AttributeError("'{}' object has no attribute '{}'".format(type(self).__name__, item))
+
+    def cat(self, tensors, *args, **kwargs):
+        if len(tensors) == 2:
+            a, b = tensors
+            if a.shape[-2:] != b.shape[-2:]:
+                a = torch.nn.functional.interpolate(a, b.shape[-2:], mode="nearest")
+
+            tensors = (a, b)
+
+        return torch.cat(tensors, *args, **kwargs)
+
+
+th = TorchHijackForUnet()
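
For illustration, a minimal sketch of the proxy in action, assuming the webui's modules package is importable; the tensor shapes below are made up:

import torch

from modules.sd_hijack_unet import th  # the proxy instance defined above

h = torch.randn(1, 320, 38, 50)     # upsampled decoder activation
skip = torch.randn(1, 320, 39, 50)  # encoder skip connection with an odd height

# Plain torch.cat would raise a size-mismatch error here; the proxy first
# resizes `h` to the skip tensor's spatial size, then concatenates channels.
out = th.cat([h, skip], dim=1)
print(out.shape)  # torch.Size([1, 640, 39, 50])

# Every other attribute falls through to torch via __getattr__:
assert th.zeros(2, 3).shape == (2, 3)

This is what lets generation sizes be multiples of 8 (the latent downscale factor) rather than 64: skip tensors whose spatial sizes drift apart after odd-sized downsampling are realigned before concatenation.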