Commit 016554e4 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

add --medvram-sdxl

parent bb7dd7b6
Loading
Loading
Loading
Loading
+1 −0
Original line number Diff line number Diff line
@@ -35,6 +35,7 @@ parser.add_argument("--hypernetwork-dir", type=str, default=os.path.join(models_
parser.add_argument("--localizations-dir", type=str, default=os.path.join(script_path, 'localizations'), help="localizations directory")
parser.add_argument("--allow-code", action='store_true', help="allow custom script execution from webui")
parser.add_argument("--medvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a little speed for low VRM usage")
parser.add_argument("--medvram-sdxl", action='store_true', help="enable --medvram optimization just for SDXL models")
parser.add_argument("--lowvram", action='store_true', help="enable stable diffusion model optimizations for sacrificing a lot of speed for very low VRM usage")
parser.add_argument("--lowram", action='store_true', help="load stable diffusion checkpoint weights to VRAM instead of RAM")
parser.add_argument("--always-batch-cond-uncond", action='store_true', help="does not do anything")
+2 −3
Original line number Diff line number Diff line
@@ -186,7 +186,6 @@ class InterrogateModels:
        res = ""
        shared.state.begin(job="interrogate")
        try:
            if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
            lowvram.send_everything_to_cpu()
            devices.torch_gc()

+16 −2
Original line number Diff line number Diff line
import torch
from modules import devices
from modules import devices, shared

module_in_gpu = None
cpu = torch.device("cpu")
@@ -14,6 +14,20 @@ def send_everything_to_cpu():
    module_in_gpu = None


def is_needed(sd_model):
    return shared.cmd_opts.lowvram or shared.cmd_opts.medvram or shared.cmd_opts.medvram_sdxl and hasattr(sd_model, 'conditioner')


def apply(sd_model):
    enable = is_needed(sd_model)
    shared.parallel_processing_allowed = not enable

    if enable:
        setup_for_low_vram(sd_model, not shared.cmd_opts.lowvram)
    else:
        sd_model.lowvram = False


def setup_for_low_vram(sd_model, use_medvram):
    if getattr(sd_model, 'lowvram', False):
        return
@@ -130,4 +144,4 @@ def setup_for_low_vram(sd_model, use_medvram):


def is_enabled(sd_model):
    return getattr(sd_model, 'lowvram', False)
    return sd_model.lowvram
+8 −8
Original line number Diff line number Diff line
@@ -517,7 +517,7 @@ def get_empty_cond(sd_model):


def send_model_to_cpu(m):
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
    if m.lowvram:
        lowvram.send_everything_to_cpu()
    else:
        m.to(devices.cpu)
@@ -525,17 +525,17 @@ def send_model_to_cpu(m):
    devices.torch_gc()


def model_target_device():
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
def model_target_device(m):
    if lowvram.is_needed(m):
        return devices.cpu
    else:
        return devices.device


def send_model_to_device(m):
    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(m, shared.cmd_opts.medvram)
    else:
    lowvram.apply(m)

    if not m.lowvram:
        m.to(shared.device)


@@ -601,7 +601,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None):
            '': torch.float16,
        }

    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(), weight_dtype_conversion=weight_dtype_conversion):
    with sd_disable_initialization.LoadStateDictOnMeta(state_dict, device=model_target_device(sd_model), weight_dtype_conversion=weight_dtype_conversion):
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
    timer.record("load weights from state dict")

@@ -743,7 +743,7 @@ def reload_model_weights(sd_model=None, info=None):
        script_callbacks.model_loaded_callback(sd_model)
        timer.record("script callbacks")

        if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        if not sd_model.lowvram:
            sd_model.to(devices.device)
            timer.record("move model to device")

+1 −1
Original line number Diff line number Diff line
@@ -47,7 +47,7 @@ def apply_unet(option=None):
    if current_unet_option is None:
        current_unet = None

        if not (shared.cmd_opts.lowvram or shared.cmd_opts.medvram):
        if not shared.sd_model.lowvram:
            shared.sd_model.model.diffusion_model.to(devices.device)

        return
Loading