Commit 8d8a05a3 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

find configs for models at runtime rather than when starting

parent 02d7abf5
Loading
Loading
Loading
Loading
+4 −1
Original line number Diff line number Diff line
@@ -97,8 +97,11 @@ def p_sample_plms(self, x, c, t, index, repeat_noise=False, use_original_steps=F


def should_hijack_inpainting(checkpoint_info):
    from modules import sd_models

    ckpt_basename = os.path.basename(checkpoint_info.filename).lower()
    cfg_basename = os.path.basename(checkpoint_info.config).lower()
    cfg_basename = os.path.basename(sd_models.find_checkpoint_config(checkpoint_info)).lower()

    return "inpainting" in ckpt_basename and not "inpainting" in cfg_basename


+18 −13
Original line number Diff line number Diff line
@@ -20,7 +20,7 @@ from modules.sd_hijack_inpainting import do_inpainting_hijack, should_hijack_inp
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
checkpoints_list = {}
checkpoints_loaded = collections.OrderedDict()

@@ -48,6 +48,14 @@ def checkpoint_tiles():
    return sorted([x.title for x in checkpoints_list.values()], key = alphanumeric_key)


def find_checkpoint_config(info):
    config = os.path.splitext(info.filename)[0] + ".yaml"
    if os.path.exists(config):
        return config

    return shared.cmd_opts.config


def list_models():
    checkpoints_list.clear()
    model_list = modelloader.load_models(model_path=model_path, command_path=shared.cmd_opts.ckpt_dir, ext_filter=[".ckpt", ".safetensors"])
@@ -73,7 +81,7 @@ def list_models():
    if os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
        shared.opts.data['sd_model_checkpoint'] = title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
@@ -81,12 +89,7 @@ def list_models():
        h = model_hash(filename)
        title, short_model_name = modeltitle(filename, h)

        basename, _ = os.path.splitext(filename)
        config = basename + ".yaml"
        if not os.path.exists(config):
            config = shared.cmd_opts.config

        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)
        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)


def get_closet_checkpoint_match(searchString):
@@ -282,9 +285,10 @@ def enable_midas_autodownload():
def load_model(checkpoint_info=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
    checkpoint_config = find_checkpoint_config(checkpoint_info)

    if checkpoint_info.config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_info.config}")
    if checkpoint_config != shared.cmd_opts.config:
        print(f"Loading config from: {checkpoint_config}")

    if shared.sd_model:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
@@ -292,7 +296,7 @@ def load_model(checkpoint_info=None):
        gc.collect()
        devices.torch_gc()

    sd_config = OmegaConf.load(checkpoint_info.config)
    sd_config = OmegaConf.load(checkpoint_config)
    
    if should_hijack_inpainting(checkpoint_info):
        # Hardcoded config for now...
@@ -302,7 +306,7 @@ def load_model(checkpoint_info=None):
        sd_config.model.params.finetune_keys = None

        # Create a "fake" config with a different name so that we know to unload it when switching models.
        checkpoint_info = checkpoint_info._replace(config=checkpoint_info.config.replace(".yaml", "-inpainting.yaml"))
        checkpoint_info = checkpoint_info._replace(config=checkpoint_config.replace(".yaml", "-inpainting.yaml"))

    if not hasattr(sd_config.model.params, "use_ema"):
        sd_config.model.params.use_ema = False
@@ -343,11 +347,12 @@ def reload_model_weights(sd_model=None, info=None):
        sd_model = shared.sd_model

    current_checkpoint_info = sd_model.sd_checkpoint_info
    checkpoint_config = find_checkpoint_config(current_checkpoint_info)

    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        return

    if sd_model.sd_checkpoint_info.config != checkpoint_info.config or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
    if checkpoint_config != find_checkpoint_config(checkpoint_info) or should_hijack_inpainting(checkpoint_info) != should_hijack_inpainting(sd_model.sd_checkpoint_info):
        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info)