Commit 24129368 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

send tensors to the correct device when loading from safetensors file with...

send tensors to the correct device when loading from safetensors file with memmap disabled for #11260
parent 14196548
Loading
Loading
Loading
Loading
+3 −1
Original line number Diff line number Diff line
@@ -246,11 +246,13 @@ def read_metadata_from_safetensors(filename):
def read_state_dict(checkpoint_file, print_global_state=False, map_location=None):
    _, extension = os.path.splitext(checkpoint_file)
    if extension.lower() == ".safetensors":
        if not shared.opts.disable_mmap_load_safetensors:
        device = map_location or shared.weight_load_location or devices.get_optimal_device_name()

        if not shared.opts.disable_mmap_load_safetensors:
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=device)
        else:
            pl_sd = safetensors.torch.load(open(checkpoint_file, 'rb').read())
            pl_sd = {k: v.to(device) for k, v in pl_sd.items()}
    else:
        pl_sd = torch.load(checkpoint_file, map_location=map_location or shared.weight_load_location)

+1 −1
Original line number Diff line number Diff line
@@ -376,7 +376,7 @@ options_templates.update(options_section(('system', "System"), {
    "multiple_tqdm": OptionInfo(True, "Add a second progress bar to the console that shows progress for an entire job."),
    "print_hypernet_extra": OptionInfo(False, "Print extra hypernetwork information to console."),
    "list_hidden_files": OptionInfo(True, "Load models/files in hidden directories").info("directory is hidden if its name starts with \".\""),
    "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files (fixes very slow loading speed in some cases)."),
    "disable_mmap_load_safetensors": OptionInfo(False, "Disable memmapping for loading .safetensors files.").info("fixes very slow loading speed in some cases"),
}))

options_templates.update(options_section(('training', "Training"), {