Commit 60397a78 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

Merge branch 'dev' into sdxl

parents da464a3f e5ca9877
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -6,11 +6,11 @@ function keyupEditOrder(event) {
    let target = event.originalTarget || event.composedPath()[0];
    if (!target.matches("*:is([id*='_toprow'] [id*='_prompt'], .prompt) textarea")) return;
    if (!event.altKey) return;
    event.preventDefault();

    let isLeft = event.key == "ArrowLeft";
    let isRight = event.key == "ArrowRight";
    if (!isLeft && !isRight) return;
    event.preventDefault();

    let selectionStart = target.selectionStart;
    let selectionEnd = target.selectionEnd;
+3 −2
Original line number Diff line number Diff line
@@ -54,8 +54,9 @@ def torch_gc():
        with torch.cuda.device(get_cuda_device_string()):
            torch.cuda.empty_cache()
            torch.cuda.ipc_collect()
    elif has_mps() and hasattr(torch.mps, 'empty_cache'):
        torch.mps.empty_cache()

    if has_mps():
        mac_specific.torch_mps_gc()


def enable_tf32():
+18 −0
Original line number Diff line number Diff line
import logging

import torch
import platform
from modules.sd_hijack_utils import CondFunc
from packaging import version

log = logging.getLogger(__name__)


# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
@@ -19,9 +23,23 @@ def check_for_mps() -> bool:
            return False
    else:
        return torch.backends.mps.is_available() and torch.backends.mps.is_built()


# Probe for MPS support once at import time; callers read this module-level
# flag instead of re-running the version/backends checks on every call.
has_mps = check_for_mps()


def torch_mps_gc() -> None:
    """Best-effort release of cached MPS memory via ``torch.mps.empty_cache``.

    Never raises: any failure is logged as a warning and swallowed, since a
    failed cache flush must not abort the caller's work.
    """
    try:
        # Lazy import — presumably avoids an import cycle with modules.shared
        # at module load time (TODO confirm against the rest of the package).
        from modules.shared import state

        if state.current_latent is None:
            from torch.mps import empty_cache
            empty_cache()
        else:
            # A live latent suggests generation is in progress; skip the flush.
            log.debug("`current_latent` is set, skipping MPS garbage collection")
    except Exception:
        log.warning("MPS garbage collection failed", exc_info=True)


# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
    if input.device.type == 'mps':