Unverified commit a2feaa95 authored by AUTOMATIC1111, committed by GitHub
Browse files

Merge pull request #5194 from brkirch/autocast-and-mps-randn-fixes

Use devices.autocast() and fix MPS randn issues
parents c7af6721 0fddb4a1
Loading
Loading
Loading
Loading
+3 −12
Original line number Diff line number Diff line
@@ -66,24 +66,15 @@ dtype_vae = torch.float16


def randn(seed, shape):
    """Return a noise tensor of `shape` seeded with `seed` on the global `device`.

    Pytorch currently doesn't handle setting randomness correctly when the
    metal (MPS) backend is used, so for MPS the noise is generated on the CPU
    (where seeding is reproducible) and then moved to the target device.
    """
    torch.manual_seed(seed)
    if device.type == 'mps':
        # Generate on CPU first; seeding the MPS generator is unreliable.
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)


def randn_without_seed(shape):
    """Return a noise tensor of `shape` using the current torch RNG state.

    Seed the RNG beforehand (e.g. via randn() or torch.manual_seed()).
    As with randn(), MPS cannot be seeded reliably, so generate on the CPU
    and transfer the result to the target device in that case.
    """
    if device.type == 'mps':
        return torch.randn(shape, device=cpu).to(device)
    return torch.randn(shape, device=device)


+1 −1
Original line number Diff line number Diff line
@@ -495,7 +495,7 @@ def train_hypernetwork(hypernetwork_name, learn_rate, batch_size, gradient_step,
                if shared.state.interrupted:
                    break

                with torch.autocast("cuda"):
                with devices.autocast():
                    x = batch.latent_sample.to(devices.device, non_blocking=pin_memory)
                    if tag_drop_out != 0 or shuffle_tags:
                        shared.sd_model.cond_stage_model.to(devices.device)
+1 −2
Original line number Diff line number Diff line
@@ -148,8 +148,7 @@ class InterrogateModels:

            clip_image = self.clip_preprocess(pil_image).unsqueeze(0).type(self.dtype).to(devices.device_interrogate)

            precision_scope = torch.autocast if shared.cmd_opts.precision == "autocast" else contextlib.nullcontext
            with torch.no_grad(), precision_scope("cuda"):
            with torch.no_grad(), devices.autocast():
                image_features = self.clip_model.encode_image(clip_image).type(self.dtype)

                image_features /= image_features.norm(dim=-1, keepdim=True)
+1 −5
Original line number Diff line number Diff line
@@ -183,11 +183,7 @@ def register_buffer(self, name, attr):

    if type(attr) == torch.Tensor:
        if attr.device != devices.device:

            if devices.has_mps():
                attr = attr.to(device="mps", dtype=torch.float32)
            else:
                attr = attr.to(devices.device)
            attr = attr.to(device=devices.device, dtype=(torch.float32 if devices.device.type == 'mps' else None))

    setattr(self, name, attr)

+19 −3
Original line number Diff line number Diff line
@@ -6,6 +6,7 @@ import tqdm
from PIL import Image
import inspect
import k_diffusion.sampling
import torchsde._brownian.brownian_interval
import ldm.models.diffusion.ddim
import ldm.models.diffusion.plms
from modules import prompt_parser, devices, processing, images
@@ -364,9 +365,25 @@ class TorchHijack:
            if noise.shape == x.shape:
                return noise

        if x.device.type == 'mps':
            return torch.randn_like(x, device=devices.cpu).to(x.device)
        else:
            return torch.randn_like(x)


# MPS fix for randn in torchsde
def torchsde_randn(size, dtype, device, seed):
    """Seeded randn replacement installed into torchsde's brownian interval.

    Pytorch cannot seed the MPS backend reliably, so when the target device
    is MPS the tensor is drawn on the CPU and then transferred; otherwise it
    is drawn directly on the requested device with a device-local generator.
    """
    needs_cpu_fallback = device.type == 'mps'
    source_device = devices.cpu if needs_cpu_fallback else device
    generator = torch.Generator(source_device).manual_seed(int(seed))
    noise = torch.randn(size, dtype=dtype, device=source_device, generator=generator)
    return noise.to(device) if needs_cpu_fallback else noise


torchsde._brownian.brownian_interval._randn = torchsde_randn


class KDiffusionSampler:
    def __init__(self, funcname, sd_model):
        denoiser = k_diffusion.external.CompVisVDenoiser if sd_model.parameterization == "v" else k_diffusion.external.CompVisDenoiser
@@ -415,8 +432,7 @@ class KDiffusionSampler:
        self.model_wrap.step = 0
        self.eta = p.eta or opts.eta_ancestral

        if self.sampler_noises is not None:
            k_diffusion.sampling.torch = TorchHijack(self.sampler_noises)
        k_diffusion.sampling.torch = TorchHijack(self.sampler_noises if self.sampler_noises is not None else [])

        extra_params_kwargs = {}
        for param_name in self.extra_params:
Loading