Commit 7ba3923d authored by AUTOMATIC's avatar AUTOMATIC
Browse files

move DDIM/PLMS fix for OSX out of the file with inpainting code.

parent bb2e2c82
Loading
Loading
Loading
Loading
+23 −0
Original line number Diff line number Diff line
@@ -14,6 +14,8 @@ from modules.sd_hijack_optimizations import invokeAI_mps_available

import ldm.modules.attention
import ldm.modules.diffusionmodules.model
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms

attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
@@ -406,3 +408,24 @@ def add_circular_option_to_conv_2d():


model_hijack = StableDiffusionModelHijack()


def register_buffer(self, name, attr):
    """
    Replacement for DDIM/PLMS sampler `register_buffer` that works around the
    Mac OS (MPS) bug where buffers end up on the wrong device.
    """

    # Only plain tensors sitting on the wrong device need to be moved;
    # everything else is stored unchanged.
    if type(attr) == torch.Tensor and attr.device != devices.device:
        # would this not break cuda when torch adds has_mps() to main version?
        if getattr(torch, 'has_mps', False):
            attr = attr.to(device="mps", dtype=torch.float32)
        else:
            attr = attr.to(devices.device)

    setattr(self, name, attr)


# Monkey-patch both samplers so their buffers are registered on the correct
# device (works around the Mac OS / MPS register_buffer bug handled above).
ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer
ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer
+1 −17
Original line number Diff line number Diff line
import torch
import modules.devices as devices

from einops import repeat
from omegaconf import ListConfig
@@ -317,20 +316,6 @@ class LatentInpaintDiffusion(LatentDiffusion):
        self.concat_keys = concat_keys


# =================================================================================================
# Fix register buffer bug for Mac OS, Viktor Tabori, viktor.doklist.com/start-here
# =================================================================================================
def register_buffer(self, name, attr):
    """
    Sampler `register_buffer` replacement fixing the Mac OS buffer-device bug
    (Viktor Tabori, viktor.doklist.com/start-here).
    """
    if type(attr) == torch.Tensor:
        target_device = devices.get_optimal_device()
        if attr.device != target_device:
            if getattr(torch, 'has_mps', False):
                attr = attr.to(device="mps", dtype=torch.float32)
            else:
                attr = attr.to(target_device)
    setattr(self, name, attr)


def should_hijack_inpainting(checkpoint_info):
    """
    Return True when the checkpoint is an inpainting model (filename ends with
    "inpainting.ckpt") that is NOT already paired with an inpainting config
    (config ending with "inpainting.yaml").
    """
    is_inpainting_ckpt = str(checkpoint_info.filename).endswith("inpainting.ckpt")
    has_inpainting_config = checkpoint_info.config.endswith("inpainting.yaml")
    return is_inpainting_ckpt and not has_inpainting_config

@@ -341,8 +326,7 @@ def do_inpainting_hijack():

    ldm.models.diffusion.ddim.DDIMSampler.p_sample_ddim = p_sample_ddim
    ldm.models.diffusion.ddim.DDIMSampler.sample = sample_ddim
    ldm.models.diffusion.ddim.DDIMSampler.register_buffer = register_buffer

    ldm.models.diffusion.plms.PLMSSampler.p_sample_plms = p_sample_plms
    ldm.models.diffusion.plms.PLMSSampler.sample = sample_plms
    ldm.models.diffusion.plms.PLMSSampler.register_buffer = register_buffer