Commit 043d2edc authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf
Browse files

Better naming

parent f383af27
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -128,7 +128,7 @@ def manual_cast_forward(self, *args, **kwargs):


@contextlib.contextmanager
def manual_autocast():
def manual_cast():
    for module_type in patch_module_list:
        org_forward = module_type.forward
        module_type.forward = manual_cast_forward
@@ -148,10 +148,10 @@ def autocast(disable=False):
        return torch.autocast("cpu", dtype=torch.bfloat16, enabled=True)

    if fp8 and (dtype == torch.float32 or shared.cmd_opts.precision == "full" or cuda_no_autocast()):
        return manual_autocast()
        return manual_cast()

    if has_mps() and shared.cmd_opts.precision != "full":
        return manual_autocast()
        return manual_cast()

    if dtype == torch.float32 or shared.cmd_opts.precision == "full":
        return contextlib.nullcontext()