Commit ac90cf38 authored by Tim Patton's avatar Tim Patton
Browse files

safetensors optional for now

parent 210cb4c1
Loading
Loading
Loading
Loading
+8 −1
Original line number Diff line number Diff line
@@ -4,7 +4,6 @@ import sys
import gc
from collections import namedtuple
import torch
from safetensors.torch import load_file, save_file
import re
from omegaconf import OmegaConf

@@ -149,6 +148,10 @@ def torch_load(model_filename, model_info, map_override=None):
        # safely load weights
        # TODO: safetensors supports zero copy fast load to gpu, see issue #684.  
        # GPU only for now, see https://github.com/huggingface/safetensors/issues/95
        try:
            from safetensors.torch import load_file
        except ImportError as e:
            raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
        return load_file(model_filename, device='cuda')
    else:
        return torch.load(model_filename, map_location=map_override)
@@ -157,6 +160,10 @@ def torch_save(model, output_filename):
    basename, exttype = os.path.splitext(output_filename)
    if(checkpoint_types[exttype] == 'safetensors'):
        # [=====  >] Reticulating brines...
        try:
            from safetensors.torch import save_file
        except ImportError as e:
            raise ImportError(f"Export as safetensors selected, yet it is not installed, use `pip install safetensors`: {e}")
        save_file(model, output_filename, metadata={"format": "pt"})
    else:
        torch.save(model, output_filename)
+0 −1
Original line number Diff line number Diff line
@@ -28,4 +28,3 @@ kornia
lark
inflection
GitPython
safetensors