Unverified Commit cb9571e3, authored by AUTOMATIC1111, committed by GitHub
Browse files

Merge pull request #9734 from deciare/cpu-randn

Option to make images generated from a given manual seed consistent across CUDA and MPS devices
parents 09069918 d40e44ad
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -92,14 +92,18 @@ def cond_cast_float(input):


def randn(seed, shape):
    """Generate a seeded random tensor of the given shape on the target device.

    When opts.use_cpu_randn is enabled (or the device is MPS, whose RNG
    differs from CUDA), the noise is generated on the CPU and then moved to
    the device, so a given manual seed produces the same image across
    platforms. Note this may change images produced by existing seeds.
    """
    # Imported lazily to avoid a circular import between modules.shared
    # and this module at load time.
    from modules.shared import opts

    torch.manual_seed(seed)
    if opts.use_cpu_randn or device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    """Generate an unseeded random tensor of the given shape on the target device.

    Uses the current global RNG state (no manual_seed call). Mirrors randn():
    when opts.use_cpu_randn is enabled or the device is MPS, noise is drawn
    on the CPU and moved to the device so results match across platforms.
    """
    # Imported lazily to avoid a circular import between modules.shared
    # and this module at load time.
    from modules.shared import opts

    if opts.use_cpu_randn or device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)

+9 −0
Original line number Diff line number Diff line
@@ -60,3 +60,12 @@ def store_latent(decoded):

class InterruptedException(BaseException):
    pass

# When CPU RNG is requested, monkey-patch torchsde's internal noise source so
# that SDE-based samplers (e.g. DPM++ SDE) also draw their Brownian noise on
# the CPU; otherwise their noise would still be device-dependent and defeat
# the cross-platform-reproducibility option.
# NOTE(review): this replaces a private torchsde symbol (_randn) — fragile
# across torchsde versions; confirm the attribute still exists on upgrade.
if opts.use_cpu_randn:
    import torchsde._brownian.brownian_interval

    def torchsde_randn(size, dtype, device, seed):
        """CPU-seeded drop-in for torchsde's _randn: same signature, but the
        tensor is generated on the CPU from `seed` and then moved to `device`."""
        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)

    torchsde._brownian.brownian_interval._randn = torchsde_randn
+1 −1
Original line number Diff line number Diff line
@@ -190,7 +190,7 @@ class TorchHijack:
            if noise.shape == x.shape:
                return noise

        if x.device.type == 'mps':
        if opts.use_cpu_randn or x.device.type == 'mps':
            return torch.randn_like(x, device=devices.cpu).to(x.device)
        else:
            return torch.randn_like(x)
+1 −0
Original line number Diff line number Diff line
@@ -334,6 +334,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
    "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
    "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
    "use_cpu_randn": OptionInfo(False, "Use CPU for random number generation to make manual seeds generate the same image across platforms. This may change existing seeds."),
}))

options_templates.update(options_section(('compatibility', "Compatibility"), {