Commit 76010a51 authored by wangqiuwen's avatar wangqiuwen
Browse files

up

parent 8e355fbd
Loading
Loading
Loading
Loading
+5 −6
Original line number Diff line number Diff line
import collections
import copy
import os.path
import sys
import gc
@@ -309,8 +310,6 @@ def get_checkpoint_state_dict(checkpoint_info: CheckpointInfo, timer):
    if checkpoint_info in checkpoints_loaded:
        # use checkpoint cache
        print(f"Loading weights [{sd_model_hash}] from cache")
        # move to end as latest
        checkpoints_loaded.move_to_end(checkpoint_info)
        return checkpoints_loaded[checkpoint_info]

    print(f"Loading weights [{sd_model_hash}] from {checkpoint_info.filename}")
@@ -352,12 +351,12 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, state_dict, timer
    if model.is_sdxl:
        sd_models_xl.extend_sdxl(model)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    if shared.opts.sd_checkpoint_cache > 0:
        # cache newly loaded model
        checkpoints_loaded[checkpoint_info] = state_dict
        checkpoints_loaded[checkpoint_info] = copy.deepcopy(state_dict)

    model.load_state_dict(state_dict, strict=False)
    timer.record("apply weights to model")

    del state_dict