Commit fca42949 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

rework torchsde._brownian.brownian_interval replacement to use...

rework torchsde._brownian.brownian_interval replacement to use device.randn_local and respect the NV setting.
parent 84b6fcd0
Loading
Loading
Loading
Loading
+38 −6
Original line number Diff line number Diff line
@@ -71,14 +71,17 @@ def enable_tf32():
        torch.backends.cudnn.allow_tf32 = True



errors.run(enable_tf32, "Enabling TF32")

cpu = torch.device("cpu")
device = device_interrogate = device_gfpgan = device_esrgan = device_codeformer = None
dtype = torch.float16
dtype_vae = torch.float16
dtype_unet = torch.float16
cpu: torch.device = torch.device("cpu")
device: torch.device = None
device_interrogate: torch.device = None
device_gfpgan: torch.device = None
device_esrgan: torch.device = None
device_codeformer: torch.device = None
dtype: torch.dtype = torch.float16
dtype_vae: torch.dtype = torch.float16
dtype_unet: torch.dtype = torch.float16
unet_needs_upcast = False


@@ -94,6 +97,10 @@ nv_rng = None


def randn(seed, shape):
    """Generate a tensor with random numbers from a normal distribution using seed.

    Uses the seed parameter to set the global torch seed; to generate more with that seed, use randn_like/randn_without_seed."""

    from modules.shared import opts

    manual_seed(seed)
@@ -107,7 +114,27 @@ def randn(seed, shape):
    return torch.randn(shape, device=device)


def randn_local(seed, shape):
    """Generate a tensor with random numbers from a normal distribution using seed.

    Does not change the global random number generator. You can only generate the seed's first tensor using this function."""

    from modules.shared import opts

    if opts.randn_source == "NV":
        rng = rng_philox.Generator(seed)
        return torch.asarray(rng.randn(shape), device=device)

    local_device = cpu if opts.randn_source == "CPU" or device.type == 'mps' else device
    local_generator = torch.Generator(local_device).manual_seed(int(seed))
    return torch.randn(shape, device=local_device, generator=local_generator).to(device)


def randn_like(x):
    """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.

    Use either randn() or manual_seed() to initialize the generator."""

    from modules.shared import opts

    if opts.randn_source == "NV":
@@ -120,6 +147,10 @@ def randn_like(x):


def randn_without_seed(shape):
    """Generate a tensor with random numbers from a normal distribution using the previously initialized genrator.

    Use either randn() or manual_seed() to initialize the generator."""

    from modules.shared import opts

    if opts.randn_source == "NV":
@@ -132,6 +163,7 @@ def randn_without_seed(shape):


def manual_seed(seed):
    """Set up a global random number generator using the specified seed."""
    from modules.shared import opts

    if opts.randn_source == "NV":
+6 −6
Original line number Diff line number Diff line
@@ -2,10 +2,8 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd

from modules import devices, processing, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared
from modules.shared import opts, state
import modules.shared as shared

SamplerData = namedtuple('SamplerData', ['name', 'constructor', 'aliases', 'options'])

@@ -85,11 +83,13 @@ class InterruptedException(BaseException):
    pass


if opts.randn_source == "CPU":
def replace_torchsde_browinan():
    import torchsde._brownian.brownian_interval

    def torchsde_randn(size, dtype, device, seed):
        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)
        return devices.randn_local(seed, size).to(device=device, dtype=dtype)

    torchsde._brownian.brownian_interval._randn = torchsde_randn


replace_torchsde_browinan()