Commit a5bbcd21 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

fix bug with "Ignore selected VAE for..." option completely disabling VAE election

rework VAE resolving code to be more simple
parent 69781031
Loading
Loading
Loading
Loading
+3 −3
Original line number Diff line number Diff line
@@ -224,7 +224,7 @@ def read_state_dict(checkpoint_file, print_global_state=False, map_location=None
    return sd


def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):
def load_model_weights(model, checkpoint_info: CheckpointInfo):
    sd_model_hash = checkpoint_info.calculate_shorthash()

    cache_enabled = shared.opts.sd_checkpoint_cache > 0
@@ -277,8 +277,8 @@ def load_model_weights(model, checkpoint_info: CheckpointInfo, vae_file="auto"):

    sd_vae.delete_base_vae()
    sd_vae.clear_loaded_vae()
    vae_file = sd_vae.resolve_vae(checkpoint_info.filename, vae_file=vae_file)
    sd_vae.load_vae(model, vae_file)
    vae_file, vae_source = sd_vae.resolve_vae(checkpoint_info.filename)
    sd_vae.load_vae(model, vae_file, vae_source)


def enable_midas_autodownload():
+77 −117
Original line number Diff line number Diff line
@@ -9,23 +9,9 @@ import glob
from copy import deepcopy


model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))
vae_dir = "VAE"
vae_path = os.path.abspath(os.path.join(models_path, vae_dir))


vae_path = os.path.abspath(os.path.join(models_path, "VAE"))
vae_ignore_keys = {"model_ema.decay", "model_ema.num_updates"}


default_vae_dict = {"auto": "auto", "None": None, None: None}
default_vae_list = ["auto", "None"]


default_vae_values = [default_vae_dict[x] for x in default_vae_list]
vae_dict = dict(default_vae_dict)
vae_list = list(default_vae_list)
first_load = True
vae_dict = {}


base_vae = None
@@ -64,100 +50,69 @@ def restore_base_vae(model):


def get_filename(filepath):
    return os.path.splitext(os.path.basename(filepath))[0]


def refresh_vae_list(vae_path=vae_path, model_path=model_path):
    global vae_dict, vae_list
    res = {}
    candidates = [
        *glob.iglob(os.path.join(model_path, '**/*.vae.ckpt'), recursive=True),
        *glob.iglob(os.path.join(model_path, '**/*.vae.pt'), recursive=True),
        *glob.iglob(os.path.join(model_path, '**/*.vae.safetensors'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.ckpt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.pt'), recursive=True),
        *glob.iglob(os.path.join(vae_path, '**/*.safetensors'), recursive=True),
    return os.path.basename(filepath)


def refresh_vae_list():
    vae_dict.clear()

    paths = [
        os.path.join(sd_models.model_path, '**/*.vae.ckpt'),
        os.path.join(sd_models.model_path, '**/*.vae.pt'),
        os.path.join(sd_models.model_path, '**/*.vae.safetensors'),
        os.path.join(vae_path, '**/*.ckpt'),
        os.path.join(vae_path, '**/*.pt'),
        os.path.join(vae_path, '**/*.safetensors'),
    ]

    if shared.cmd_opts.ckpt_dir is not None and os.path.isdir(shared.cmd_opts.ckpt_dir):
        paths += [
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.ckpt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.pt'),
            os.path.join(shared.cmd_opts.ckpt_dir, '**/*.vae.safetensors'),
        ]
    if shared.cmd_opts.vae_path is not None and os.path.isfile(shared.cmd_opts.vae_path):
        candidates.append(shared.cmd_opts.vae_path)

    candidates = []
    for path in paths:
        candidates += glob.iglob(path, recursive=True)

    for filepath in candidates:
        name = get_filename(filepath)
        res[name] = filepath
    vae_list.clear()
    vae_list.extend(default_vae_list)
    vae_list.extend(list(res.keys()))
    vae_dict.clear()
    vae_dict.update(res)
    vae_dict.update(default_vae_dict)
    return vae_list


def get_vae_from_settings(vae_file="auto"):
    # else, we load from settings, if not set to be default
    if vae_file == "auto" and shared.opts.sd_vae is not None:
        # if saved VAE settings isn't recognized, fallback to auto
        vae_file = vae_dict.get(shared.opts.sd_vae, "auto")
        # if VAE selected but not found, fallback to auto
        if vae_file not in default_vae_values and not os.path.isfile(vae_file):
            vae_file = "auto"
            print(f"Selected VAE doesn't exist: {vae_file}")
    return vae_file


def resolve_vae(checkpoint_file=None, vae_file="auto"):
    global first_load, vae_dict, vae_list

    # if vae_file argument is provided, it takes priority, but not saved
    if vae_file and vae_file not in default_vae_list:
        if not os.path.isfile(vae_file):
            print(f"VAE provided as function argument doesn't exist: {vae_file}")
            vae_file = "auto"
    # for the first load, if vae-path is provided, it takes priority, saved, and failure is reported
    if first_load and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            shared.opts.data['sd_vae'] = get_filename(vae_file)
        else:
            print(f"VAE provided as command line argument doesn't exist: {vae_file}")
    # fallback to selector in settings, if vae selector not set to act as default fallback
    if not shared.opts.sd_vae_as_default:
        vae_file = get_vae_from_settings(vae_file)
    # vae-path cmd arg takes priority for auto
    if vae_file == "auto" and shared.cmd_opts.vae_path is not None:
        if os.path.isfile(shared.cmd_opts.vae_path):
            vae_file = shared.cmd_opts.vae_path
            print(f"Using VAE provided as command line argument: {vae_file}")
    # if still not found, try look for ".vae.pt" beside model
    model_path = os.path.splitext(checkpoint_file)[0]
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.pt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print(f"Using VAE found similar to selected model: {vae_file}")
    # if still not found, try look for ".vae.ckpt" beside model
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.ckpt"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print(f"Using VAE found similar to selected model: {vae_file}")
    # if still not found, try look for ".vae.safetensors" beside model
    if vae_file == "auto":
        vae_file_try = model_path + ".vae.safetensors"
        if os.path.isfile(vae_file_try):
            vae_file = vae_file_try
            print(f"Using VAE found similar to selected model: {vae_file}")
    # No more fallbacks for auto
    if vae_file == "auto":
        vae_file = None
    # Last check, just because
    if vae_file and not os.path.exists(vae_file):
        vae_file = None

    return vae_file


def load_vae(model, vae_file=None):
    global first_load, vae_dict, vae_list, loaded_vae_file
        vae_dict[name] = filepath


def find_vae_near_checkpoint(checkpoint_file):
    checkpoint_path = os.path.splitext(checkpoint_file)[0]
    for vae_location in [checkpoint_path + ".vae.pt", checkpoint_path + ".vae.ckpt", checkpoint_path + ".vae.safetensors"]:
        if os.path.isfile(vae_location):
            return vae_location

    return None


def resolve_vae(checkpoint_file):
    if shared.cmd_opts.vae_path is not None:
        return shared.cmd_opts.vae_path, 'from commandline argument'

    vae_near_checkpoint = find_vae_near_checkpoint(checkpoint_file)
    if vae_near_checkpoint is not None and (shared.opts.sd_vae_as_default or shared.opts.sd_vae == "auto"):
        return vae_near_checkpoint, 'found near the checkpoint'

    if shared.opts.sd_vae == "None":
        return None, None

    vae_from_options = vae_dict.get(shared.opts.sd_vae, None)
    if vae_from_options is not None:
        return vae_from_options, 'specified in settings'

    if shared.opts.sd_vae != "Automatic":
        print(f"Couldn't find VAE named {shared.opts.sd_vae}; using None instead")

    return None, None


def load_vae(model, vae_file=None, vae_source="from unknown source"):
    global vae_dict, loaded_vae_file
    # save_settings = False

    cache_enabled = shared.opts.sd_vae_checkpoint_cache > 0
@@ -165,12 +120,12 @@ def load_vae(model, vae_file=None):
    if vae_file:
        if cache_enabled and vae_file in checkpoints_loaded:
            # use vae checkpoint cache
            print(f"Loading VAE weights [{get_filename(vae_file)}] from cache")
            print(f"Loading VAE weights {vae_source}: cached {get_filename(vae_file)}")
            store_base_vae(model)
            _load_vae_dict(model, checkpoints_loaded[vae_file])
        else:
            assert os.path.isfile(vae_file), f"VAE file doesn't exist: {vae_file}"
            print(f"Loading VAE weights from: {vae_file}")
            assert os.path.isfile(vae_file), f"VAE {vae_source} doesn't exist: {vae_file}"
            print(f"Loading VAE weights {vae_source}: {vae_file}")
            store_base_vae(model)

            vae_ckpt = sd_models.read_state_dict(vae_file, map_location=shared.weight_load_location)
@@ -191,14 +146,12 @@ def load_vae(model, vae_file=None):
        vae_opt = get_filename(vae_file)
        if vae_opt not in vae_dict:
            vae_dict[vae_opt] = vae_file
            vae_list.append(vae_opt)

    elif loaded_vae_file:
        restore_base_vae(model)

    loaded_vae_file = vae_file

    first_load = False


# don't call this from outside
def _load_vae_dict(model, vae_dict_1):
@@ -211,7 +164,10 @@ def clear_loaded_vae():
    loaded_vae_file = None


def reload_vae_weights(sd_model=None, vae_file="auto"):
unspecified = object()


def reload_vae_weights(sd_model=None, vae_file=unspecified):
    from modules import lowvram, devices, sd_hijack

    if not sd_model:
@@ -219,7 +175,11 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):

    checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_file = checkpoint_info.filename
    vae_file = resolve_vae(checkpoint_file, vae_file=vae_file)

    if vae_file == unspecified:
        vae_file, vae_source = resolve_vae(checkpoint_file)
    else:
        vae_source = "from function argument"

    if loaded_vae_file == vae_file:
        return
@@ -231,7 +191,7 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_vae(sd_model, vae_file)
    load_vae(sd_model, vae_file, vae_source)

    sd_hijack.model_hijack.hijack(sd_model)
    script_callbacks.model_loaded_callback(sd_model)
@@ -239,5 +199,5 @@ def reload_vae_weights(sd_model=None, vae_file="auto"):
    if not shared.cmd_opts.lowvram and not shared.cmd_opts.medvram:
        sd_model.to(devices.device)

    print("VAE Weights loaded.")
    print("VAE weights loaded.")
    return sd_model
+2 −2
Original line number Diff line number Diff line
@@ -83,7 +83,7 @@ parser.add_argument("--theme", type=str, help="launches the UI with light or dar
parser.add_argument("--use-textbox-seed", action='store_true', help="use textbox for seeds in UI (no up/down, but possible to input long seeds)", default=False)
parser.add_argument("--disable-console-progressbars", action='store_true', help="do not output progressbars to console", default=False)
parser.add_argument("--enable-console-prompts", action='store_true', help="print prompts to console when generating with txt2img and img2img", default=False)
parser.add_argument('--vae-path', type=str, help='Path to Variational Autoencoders model', default=None)
parser.add_argument('--vae-path', type=str, help='Checkpoint to use as VAE; setting this argument disables all settings related to VAE', default=None)
parser.add_argument("--disable-safe-unpickle", action='store_true', help="disable checking pytorch models for malicious code", default=False)
parser.add_argument("--api", action='store_true', help="use api=True to launch the API together with the webui (use --nowebui instead for only the API)")
parser.add_argument("--api-auth", type=str, help='Set authentication for API like "username:password"; or comma-delimit multiple like "u1:p1,u2:p2,u3:p3"', default=None)
@@ -383,7 +383,7 @@ options_templates.update(options_section(('sd', "Stable Diffusion"), {
    "sd_model_checkpoint": OptionInfo(None, "Stable Diffusion checkpoint", gr.Dropdown, lambda: {"choices": list_checkpoint_tiles()}, refresh=refresh_checkpoints),
    "sd_checkpoint_cache": OptionInfo(0, "Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
    "sd_vae_checkpoint_cache": OptionInfo(0, "VAE Checkpoints to cache in RAM", gr.Slider, {"minimum": 0, "maximum": 10, "step": 1}),
    "sd_vae": OptionInfo("auto", "SD VAE", gr.Dropdown, lambda: {"choices": sd_vae.vae_list}, refresh=sd_vae.refresh_vae_list),
    "sd_vae": OptionInfo("Automatic", "SD VAE", gr.Dropdown, lambda: {"choices": ["Automatic", "None"] + list(sd_vae.vae_dict)}, refresh=sd_vae.refresh_vae_list),
    "sd_vae_as_default": OptionInfo(False, "Ignore selected VAE for stable diffusion checkpoints that have their own .vae.pt next to them"),
    "sd_hypernetwork": OptionInfo("None", "Hypernetwork", gr.Dropdown, lambda: {"choices": ["None"] + [x for x in hypernetworks.keys()]}, refresh=reload_hypernetworks),
    "sd_hypernetwork_strength": OptionInfo(1.0, "Hypernetwork strength", gr.Slider, {"minimum": 0.0, "maximum": 1.0, "step": 0.001}),
+13 −14
Original line number Diff line number Diff line
@@ -125,24 +125,21 @@ def apply_upscale_latent_space(p, x, xs):


def find_vae(name: str):
    if name.lower() in ['auto', 'none']:
        return name
    if name.lower() in ['auto', 'automatic']:
        return modules.sd_vae.unspecified
    if name.lower() == 'none':
        return None
    else:
        vae_path = os.path.abspath(os.path.join(paths.models_path, 'VAE'))
        found = glob.glob(os.path.join(vae_path, f'**/{name}.*pt'), recursive=True)
        if found:
            return found[0]
        choices = [x for x in sorted(modules.sd_vae.vae_dict, key=lambda x: len(x)) if name.lower().strip() in x.lower()]
        if len(choices) == 0:
            print(f"No VAE found for {name}; using automatic")
            return modules.sd_vae.unspecified
        else:
            return 'auto'
            return modules.sd_vae.vae_dict[choices[0]]


def apply_vae(p, x, xs):
    if x.lower().strip() == 'none':
        modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file='None')
    else:
        found = find_vae(x)
        if found:
            v = modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=found)
    modules.sd_vae.reload_vae_weights(shared.sd_model, vae_file=find_vae(x))


def apply_styles(p: StableDiffusionProcessingTxt2Img, x: str, _):
@@ -271,7 +268,9 @@ class SharedSettingsStackHelper(object):
  
    def __exit__(self, exc_type, exc_value, tb):
        modules.sd_models.reload_model_weights(self.model)
        modules.sd_vae.reload_vae_weights(self.model, vae_file=find_vae(self.vae))

        opts.data["sd_vae"] = self.vae
        modules.sd_vae.reload_vae_weights(self.model)

        hypernetwork.load_hypernetwork(self.hypernetwork)
        hypernetwork.apply_strength()