Commit f55a7e04 authored by RcINS's avatar RcINS
Browse files

Fix error when batch count > 1

parent 9e27af76
Loading
Loading
Loading
Loading
+5 −4
Original line number Diff line number Diff line
@@ -269,14 +269,15 @@ class KDiffusionSampler:

        return sigmas

    def create_noise_sampler(self, x, sigmas, seeds):
    def create_noise_sampler(self, x, sigmas, p):
        """For DPM++ SDE: manually create noise sampler to enable deterministic results across different batch sizes"""
        if shared.opts.no_dpmpp_sde_batch_determinism:
            return None

        from k_diffusion.sampling import BrownianTreeNoiseSampler
        sigma_min, sigma_max = sigmas[sigmas > 0].min(), sigmas.max()
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=seeds)
        current_iter_seeds = p.all_seeds[p.iteration * p.batch_size:(p.iteration + 1) * p.batch_size]
        return BrownianTreeNoiseSampler(x, sigma_min, sigma_max, seed=current_iter_seeds)

    def sample_img2img(self, p, x, noise, conditioning, unconditional_conditioning, steps=None, image_conditioning=None):
        steps, t_enc = sd_samplers_common.setup_img2img_steps(p, steps)
@@ -302,7 +303,7 @@ class KDiffusionSampler:
            extra_params_kwargs['sigmas'] = sigma_sched

        if self.funcname == 'sample_dpmpp_sde':
            noise_sampler = self.create_noise_sampler(x, sigmas, p.all_seeds)
            noise_sampler = self.create_noise_sampler(x, sigmas, p)
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.model_wrap_cfg.init_latent = x
@@ -337,7 +338,7 @@ class KDiffusionSampler:
            extra_params_kwargs['sigmas'] = sigmas

        if self.funcname == 'sample_dpmpp_sde':
            noise_sampler = self.create_noise_sampler(x, sigmas, p.all_seeds)
            noise_sampler = self.create_noise_sampler(x, sigmas, p)
            extra_params_kwargs['noise_sampler'] = noise_sampler

        self.last_latent = x