Commit b50ff4f4 authored by Josh Watzman's avatar Josh Watzman

Reduce peak memory usage when changing models

A few tweaks to reduce peak memory usage, the biggest being that if we
aren't using the checkpoint cache, we shouldn't duplicate the model
state dict just to immediately throw it away.

On my machine with 16GB of RAM, this change means I can usually change
models successfully, whereas before doing so would typically OOM.
parent 737eb28f
+7 −4
@@ -170,7 +170,9 @@ def load_model_weights(model, checkpoint_info):
             print(f"Global Step: {pl_sd['global_step']}")
 
         sd = get_state_dict_from_checkpoint(pl_sd)
-        missing, extra = model.load_state_dict(sd, strict=False)
+        del pl_sd
+        model.load_state_dict(sd, strict=False)
+        del sd
 
         if shared.cmd_opts.opt_channelslast:
             model.to(memory_format=torch.channels_last)
@@ -194,6 +196,7 @@ def load_model_weights(model, checkpoint_info):
 
         model.first_stage_model.to(devices.dtype_vae)
 
-        checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
-        while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
-            checkpoints_loaded.popitem(last=False)  # LRU
+        if shared.opts.sd_checkpoint_cache > 0:
+            checkpoints_loaded[checkpoint_info] = model.state_dict().copy()
+            while len(checkpoints_loaded) > shared.opts.sd_checkpoint_cache:
+                checkpoints_loaded.popitem(last=False)  # LRU