Commit 87b50397 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

add missing import, simplify code, use patches module for #13276

parent e309583f
Loading
Loading
Loading
Loading
+12 −7
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ import ldm.modules.midas as midas

from ldm.util import instantiate_from_config

from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack
from modules import paths, shared, modelloader, devices, script_callbacks, sd_vae, sd_disable_initialization, errors, hashes, sd_models_config, sd_unet, sd_models_xl, cache, extra_networks, processing, lowvram, sd_hijack, patches
from modules.timer import Timer
import tomesd
import numpy as np
@@ -130,6 +130,8 @@ except Exception:


def setup_model():
    """called once at startup to do various one-time tasks related to SD models"""

    os.makedirs(model_path, exist_ok=True)

    enable_midas_autodownload()
@@ -458,14 +460,17 @@ def enable_midas_autodownload():


def patch_given_betas():
    original_register_schedule = ldm.models.diffusion.ddpm.DDPM.register_schedule
    import ldm.models.diffusion.ddpm

    def patched_register_schedule(*args, **kwargs):
        if args[1] is not None and isinstance(args[1], ListConfig):
            modified_args = list(args)  # Convert args tuple to a list
            modified_args[1] = np.array(args[1])  # Modify the desired element
            args = tuple(modified_args)  # Convert the list back to a tuple
        """a modified version of register_schedule function that converts plain list from Omegaconf into numpy"""

        if isinstance(args[1], ListConfig):
            args = (args[0], np.array(args[1]), *args[2:])

        original_register_schedule(*args, **kwargs)
    ldm.models.diffusion.ddpm.DDPM.register_schedule = patched_register_schedule

    original_register_schedule = patches.patch(__name__, ldm.models.diffusion.ddpm.DDPM, 'register_schedule', patched_register_schedule)


def repair_config(sd_config):