Commit faed465a authored by brkirch, committed by AUTOMATIC1111
Browse files

MPS Upscalers Fix

Get ESRGAN, SCUNet, and SwinIR working correctly on MPS by ensuring memory is contiguous for tensor views before sending to MPS device.
parent 4c24347e
Loading
Loading
Loading
Loading
+4 −0
Original line number Diff line number Diff line
@@ -81,3 +81,7 @@ def autocast(disable=False):
        return contextlib.nullcontext()

    return torch.autocast("cuda")

# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
def mps_contiguous(input_tensor, device):
    """Return ``input_tensor`` made contiguous when ``device`` is an MPS device.

    Part of the MPS workaround for
    https://github.com/pytorch/pytorch/issues/79383.

    Args:
        input_tensor: Object exposing a ``contiguous()`` method; it is only
            called when the target device type is ``'mps'``.
        device: Object whose ``type`` attribute names the target backend.

    Returns:
        ``input_tensor.contiguous()`` if ``device.type == 'mps'``, otherwise
        ``input_tensor`` itself, untouched.
    """
    if device.type == 'mps':
        return input_tensor.contiguous()
    return input_tensor
def mps_contiguous_to(input_tensor, device):
    """Send ``input_tensor`` to ``device`` after passing it through ``mps_contiguous``.

    Args:
        input_tensor: Tensor to transfer.
        device: Target device; forwarded to both ``mps_contiguous`` and ``.to()``.

    Returns:
        The result of calling ``.to(device)`` on whatever ``mps_contiguous``
        returned for ``input_tensor``.
    """
    prepared = mps_contiguous(input_tensor, device)
    return prepared.to(device)
+1 −1
Original line number Diff line number Diff line
@@ -190,7 +190,7 @@ def upscale_without_tiling(model, img):
    img = img[:, :, ::-1]
    img = np.ascontiguousarray(np.transpose(img, (2, 0, 1))) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(devices.device_esrgan)
    img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_esrgan)
    with torch.no_grad():
        output = model(img)
    output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+1 −2
Original line number Diff line number Diff line
@@ -54,9 +54,8 @@ class UpscalerScuNET(modules.upscaler.Upscaler):
        img = img[:, :, ::-1]
        img = np.moveaxis(img, 2, 0) / 255
        img = torch.from_numpy(img).float()
        img = img.unsqueeze(0).to(device)
        img = devices.mps_contiguous_to(img.unsqueeze(0), device)

        img = img.to(device)
        with torch.no_grad():
            output = model(img)
        output = output.squeeze().float().cpu().clamp_(0, 1).numpy()
+1 −1
Original line number Diff line number Diff line
@@ -111,7 +111,7 @@ def upscale(
    img = img[:, :, ::-1]
    img = np.moveaxis(img, 2, 0) / 255
    img = torch.from_numpy(img).float()
    img = img.unsqueeze(0).to(devices.device_swinir)
    img = devices.mps_contiguous_to(img.unsqueeze(0), devices.device_swinir)
    with torch.no_grad(), precision_scope("cuda"):
        _, _, h_old, w_old = img.size()
        h_pad = (h_old // window_size + 1) * window_size - h_old