Commit 6074175f authored by AUTOMATIC's avatar AUTOMATIC
Browse files

add safetensors to requirements

parent f108782e
Loading
Loading
Loading
Loading
+5 −6
Original line number Original line Diff line number Diff line
@@ -5,6 +5,7 @@ import gc
from collections import namedtuple
from collections import namedtuple
import torch
import torch
import re
import re
import safetensors.torch
from omegaconf import OmegaConf
from omegaconf import OmegaConf


from ldm.util import instantiate_from_config
from ldm.util import instantiate_from_config
@@ -173,14 +174,12 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
        # load from file
        # load from file
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")


        if checkpoint_file.endswith(".safetensors"):
        _, extension = os.path.splitext(checkpoint_file)
            try:
        if extension.lower() == ".safetensors":
                from safetensors.torch import load_file
            pl_sd = safetensors.torch.load_file(checkpoint_file, device=shared.weight_load_location)
            except ImportError as e:
                raise ImportError(f"The model is in safetensors format and it is not installed, use `pip install safetensors`: {e}")
            pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
        else:
        else:
            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)

        if "global_step" in pl_sd:
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
            print(f"Global Step: {pl_sd['global_step']}")


+1 −0
Original line number Original line Diff line number Diff line
@@ -29,3 +29,4 @@ lark
inflection
inflection
GitPython
GitPython
torchsde
torchsde
safetensors
+1 −0
Original line number Original line Diff line number Diff line
@@ -26,3 +26,4 @@ lark==1.1.2
inflection==0.5.1
inflection==0.5.1
GitPython==3.1.27
GitPython==3.1.27
torchsde==0.2.5
torchsde==0.2.5
safetensors==0.2.5