Commit ac7ecd2d authored by Tim Patton's avatar Tim Patton
Browse files

Label and load SD .safetensors model files

parent 47a44c7e
Loading
Loading
Loading
Loading
+1 −0
Original line number Original line Diff line number Diff line
@@ -84,6 +84,7 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- API
- API
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. 
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. 
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- via extension: [Aesthetic Gradients](https://github.com/AUTOMATIC1111/stable-diffusion-webui-aesthetic-gradients), a way to generate images with a specific aesthetic by using clip images embds (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))
- Can use safetensors to safely load model files without python pickle


## Where are Aesthetic Gradients?!?!
## Where are Aesthetic Gradients?!?!
Aesthetic Gradients are now an extension. You can install it using git:
Aesthetic Gradients are now an extension. You can install it using git:
+1 −0
Original line number Original line Diff line number Diff line
@@ -82,6 +82,7 @@ def cleanup_models():
    src_path = models_path
    src_path = models_path
    dest_path = os.path.join(models_path, "Stable-diffusion")
    dest_path = os.path.join(models_path, "Stable-diffusion")
    move_files(src_path, dest_path, ".ckpt")
    move_files(src_path, dest_path, ".ckpt")
    move_files(src_path, dest_path, ".safetensors")
    src_path = os.path.join(root_path, "ESRGAN")
    src_path = os.path.join(root_path, "ESRGAN")
    dest_path = os.path.join(models_path, "ESRGAN")
    dest_path = os.path.join(models_path, "ESRGAN")
    move_files(src_path, dest_path)
    move_files(src_path, dest_path)
+16 −8
Original line number Original line Diff line number Diff line
@@ -4,6 +4,7 @@ import sys
import gc
import gc
from collections import namedtuple
from collections import namedtuple
import torch
import torch
from safetensors.torch import load_file
import re
import re
from omegaconf import OmegaConf
from omegaconf import OmegaConf


@@ -16,9 +17,10 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion"
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
model_path = os.path.abspath(os.path.join(models_path, model_dir))


CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config', 'exttype'])
checkpoints_list = {}
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()
checkpoints_loaded = collections.OrderedDict()
checkpoint_types = {'.ckpt':'pickle','.safetensors':'safetensors'}


try:
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.
@@ -45,7 +47,7 @@ def checkpoint_tiles():


def list_models():
def list_models():
    checkpoints_list.clear()
    checkpoints_list.clear()
    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt"])
    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt",".safetensors"])


    def modeltitle(path, shorthash):
    def modeltitle(path, shorthash):
        abspath = os.path.abspath(path)
        abspath = os.path.abspath(path)
@@ -60,15 +62,15 @@ def list_models():
        if name.startswith("\\") or name.startswith("/"):
        if name.startswith("\\") or name.startswith("/"):
            name = name[1:]
            name = name[1:]


        shortname = os.path.splitext(name.replace("/", "_").replace("\\", "_"))[0]
        shortname, ext = os.path.splitext(name.replace("/", "_").replace("\\", "_"))


        return f'{name} [{shorthash}]', shortname
        return f'{name} [{checkpoint_types[ext]}] [{shorthash}]', shortname


    cmd_ckpt = shared.cmd_opts.ckpt
    cmd_ckpt = shared.cmd_opts.ckpt
    if os.path.exists(cmd_ckpt):
    if os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        h = model_hash(cmd_ckpt)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config, '')
        shared.opts.data['sd_model_checkpoint'] = title
        shared.opts.data['sd_model_checkpoint'] = title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -76,12 +78,12 @@ def list_models():
        h = model_hash(filename)
        h = model_hash(filename)
        title, short_model_name = modeltitle(filename, h)
        title, short_model_name = modeltitle(filename, h)


        basename, _ = os.path.splitext(filename)
        basename, ext = os.path.splitext(filename)
        config = basename + ".yaml"
        config = basename + ".yaml"
        if not os.path.exists(config):
        if not os.path.exists(config):
            config = shared.cmd_opts.config
            config = shared.cmd_opts.config


        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config, ext)




def get_closet_checkpoint_match(searchString):
def get_closet_checkpoint_match(searchString):
@@ -173,7 +175,13 @@ def load_model_weights(model, checkpoint_info, vae_file="auto"):
        # load from file
        # load from file
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")
        print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")


        if(checkpoint_types[checkpoint_info.exttype] == 'safetensors'):
            # safely load weights
            # TODO: safetensors supports zero copy fast load to gpu, see issue #684
            pl_sd = load_file(checkpoint_file, device=shared.weight_load_location)
        else:
            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)
            pl_sd = torch.load(checkpoint_file, map_location=shared.weight_load_location)

        if "global_step" in pl_sd:
        if "global_step" in pl_sd:
            print(f"Global Step: {pl_sd['global_step']}")
            print(f"Global Step: {pl_sd['global_step']}")


+1 −0
Original line number Original line Diff line number Diff line
@@ -28,3 +28,4 @@ kornia
lark
lark
inflection
inflection
GitPython
GitPython
safetensors