Commit 8faac8b9 authored by AUTOMATIC

run basic torch calculation at startup in parallel to reduce the performance impact of first generation
parent 1f318292
modules/devices.py +18 −0
 import sys
 import contextlib
+from functools import lru_cache
+
 import torch
 from modules import errors

@@ -154,3 +156,19 @@ def test_for_nans(x, where):
     message += " Use --disable-nan-check commandline argument to disable this check."
 
     raise NansException(message)
+
+
+@lru_cache
+def first_time_calculation():
+    """
+    just do any calculation with pytorch layers - the first time this is done it allocates about 700MB of memory and
+    spends about 2.7 seconds doing that, at least with NVidia.
+    """
+
+    x = torch.zeros((1, 1)).to(device, dtype)
+    linear = torch.nn.Linear(1, 1).to(device, dtype)
+    linear(x)
+
+    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
+    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
+    conv2d(x)
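
For context, first_time_calculation is the whole trick: pushing a tiny tensor through a Linear and a Conv2d layer forces torch to do its one-time device setup (CUDA context, kernel loading, memory allocation) before the first real generation needs it. Below is a minimal standalone sketch of the same idea that times a first call against a second; device and dtype are stand-in choices here, since in the commit they are module-level globals of modules/devices.py:

import time

import torch

# Stand-in values; in modules/devices.py these are module-level globals.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
dtype = torch.float16 if device.type == "cuda" else torch.float32

def warmup():
    # Same operations as first_time_calculation above.
    x = torch.zeros((1, 1)).to(device, dtype)
    linear = torch.nn.Linear(1, 1).to(device, dtype)
    linear(x)

    x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
    conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
    conv2d(x)

for label in ("first call", "second call"):
    start = time.time()
    warmup()
    # The first call pays the one-time initialization cost; repeats are cheap.
    print(f"{label}: {time.time() - start:.2f}s")
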
webui.py +3 −1
@@ -20,7 +20,7 @@ import logging
 
 logging.getLogger("xformers").addFilter(lambda record: 'A matching Triton is not available' not in record.getMessage())
 
-from modules import paths, timer, import_hook, errors  # noqa: F401
+from modules import paths, timer, import_hook, errors, devices  # noqa: F401
 
 startup_timer = timer.Timer()

@@ -295,6 +295,8 @@ def initialize_rest(*, reload_script_modules=False):
     # (when reloading, this does nothing)
     Thread(target=lambda: shared.sd_model).start()
 
+    Thread(target=devices.first_time_calculation).start()
+
     shared.reload_hypernetworks()
     startup_timer.record("reload hypernetworks")
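
Taken together, the two hunks are a fire-and-forget warm-up: the Thread overlaps torch's initialization with the rest of startup, and @lru_cache makes the no-argument function effectively run-once, so any later caller gets the cached result instantly. A minimal sketch of that pattern, with expensive_init and the sleep as hypothetical stand-ins for the warm-up and the rest of startup; note that lru_cache does not lock, so two threads racing the very first call could both execute the body, which is harmless here because the work is idempotent:

import time
from functools import lru_cache
from threading import Thread

@lru_cache
def expensive_init():
    # Stand-in for first_time_calculation: slow the first time, cached after.
    time.sleep(2)
    print("one-time init done")

t = Thread(target=expensive_init)
t.start()                 # warm up in the background, as initialize_rest does

time.sleep(1)             # stand-in for the rest of startup running in parallel

t.join()
expensive_init()          # already cached: returns immediately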