Commit 2ab64ec8 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

emergency fix for #1199

parent 15f333a2
Loading
Loading
Loading
Loading
+13 −12
Original line number Diff line number Diff line
@@ -3,6 +3,7 @@ import numpy as np
import torch
import tqdm
from PIL import Image
import inspect

import k_diffusion.sampling
import ldm.models.diffusion.ddim
@@ -278,9 +279,9 @@ class KDiffusionSampler:
            k_diffusion.sampling.torch = TorchHijack(self)

        extra_params_kwargs = {}
        for val in self.extra_params:
          if hasattr(p,val):
            extra_params_kwargs[val] = getattr(p,val)
        for param_name in self.extra_params:
            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        return self.func(self.model_wrap_cfg, xi, sigma_sched, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)

@@ -300,9 +301,9 @@ class KDiffusionSampler:
            k_diffusion.sampling.torch = TorchHijack(self)

        extra_params_kwargs = {}
        for val in self.extra_params:
          if hasattr(p,val):
            extra_params_kwargs[val] = getattr(p,val)
        for param_name in self.extra_params:
            if hasattr(p, param_name) and param_name in inspect.signature(self.func).parameters:
                extra_params_kwargs[param_name] = getattr(p, param_name)

        samples = self.func(self.model_wrap_cfg, x, sigmas, extra_args={'cond': conditioning, 'uncond': unconditional_conditioning, 'cond_scale': p.cfg_scale}, disable=False, callback=self.callback_state, **extra_params_kwargs)