Commit d40e44ad authored by Deciare's avatar Deciare Committed by Deciare
Browse files

Option to use CPU for random number generation.

Makes a given manual seed generate the same images across different
platforms, independently of the GPU architecture in use.

Fixes #9613.
parent 22bcc7be
Loading
Loading
Loading
Loading
+6 −2
Original line number Diff line number Diff line
@@ -92,14 +92,18 @@ def cond_cast_float(input):


def randn(seed, shape):
    from modules.shared import opts

    torch.manual_seed(seed)
    if device.type == 'mps':
    if opts.use_cpu_randn or device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    if device.type == 'mps':
    from modules.shared import opts

    if opts.use_cpu_randn or device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)

+9 −0
Original line number Diff line number Diff line
@@ -60,3 +60,12 @@ def store_latent(decoded):

class InterruptedException(BaseException):
    pass

if opts.use_cpu_randn:
    import torchsde._brownian.brownian_interval

    def torchsde_randn(size, dtype, device, seed):
        generator = torch.Generator(devices.cpu).manual_seed(int(seed))
        return torch.randn(size, dtype=dtype, device=devices.cpu, generator=generator).to(device)

    torchsde._brownian.brownian_interval._randn = torchsde_randn
+1 −1
Original line number Diff line number Diff line
@@ -190,7 +190,7 @@ class TorchHijack:
            if noise.shape == x.shape:
                return noise

        if x.device.type == 'mps':
        if opts.use_cpu_randn or x.device.type == 'mps':
            return torch.randn_like(x, device=devices.cpu).to(x.device)
        else:
            return torch.randn_like(x)
+1 −0
Original line number Diff line number Diff line
@@ -331,6 +331,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "comma_padding_backtrack": OptionInfo(20, "Increase coherency by padding from the last comma within n tokens when using more than 75 tokens", gr.Slider, {"minimum": 0, "maximum": 74, "step": 1 }),
    "CLIP_stop_at_last_layers": OptionInfo(1, "Clip skip", gr.Slider, {"minimum": 1, "maximum": 12, "step": 1}),
    "upcast_attn": OptionInfo(False, "Upcast cross attention layer to float32"),
    "use_cpu_randn": OptionInfo(False, "Use CPU for random number generation to make manual seeds generate the same image across platforms. This may change existing seeds."),
}))

options_templates.update(options_section(('compatibility', "Compatibility"), {