Commit 050a6a79 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

support loading .yaml config with same name as model

support EMA weights in processing (????)
parent 43278216
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -347,7 +347,7 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
    infotexts = []
    output_images = []

    with torch.no_grad():
    with torch.no_grad(), p.sd_model.ema_scope():
        with devices.autocast():
            p.init(all_prompts, all_seeds, all_subseeds)

+23 −7
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ from modules.paths import models_path
model_dir = "Stable-diffusion"
model_path = os.path.abspath(os.path.join(models_path, model_dir))

CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name'])
CheckpointInfo = namedtuple("CheckpointInfo", ['filename', 'title', 'hash', 'model_name', 'config'])
checkpoints_list = {}

try:
@@ -63,14 +63,20 @@ def list_models():
    if os.path.exists(cmd_ckpt):
        h = model_hash(cmd_ckpt)
        title, short_model_name = modeltitle(cmd_ckpt, h)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name)
        checkpoints_list[title] = CheckpointInfo(cmd_ckpt, title, h, short_model_name, shared.cmd_opts.config)
        shared.opts.data['sd_model_checkpoint'] = title
    elif cmd_ckpt is not None and cmd_ckpt != shared.default_sd_model_file:
        print(f"Checkpoint in --ckpt argument not found (Possible it was moved to {model_path}: {cmd_ckpt}", file=sys.stderr)
    for filename in model_list:
        h = model_hash(filename)
        title, short_model_name = modeltitle(filename, h)
        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name)

        basename, _ = os.path.splitext(filename)
        config = basename + ".yaml"
        if not os.path.exists(config):
            config = shared.cmd_opts.config

        checkpoints_list[title] = CheckpointInfo(filename, title, h, short_model_name, config)


def get_closet_checkpoint_match(searchString):
@@ -116,7 +122,10 @@ def select_checkpoint():
    return checkpoint_info


def load_model_weights(model, checkpoint_file, sd_model_hash):
def load_model_weights(model, checkpoint_info):
    checkpoint_file = checkpoint_info.filename
    sd_model_hash = checkpoint_info.hash

    print(f"Loading weights [{sd_model_hash}] from {checkpoint_file}")

    pl_sd = torch.load(checkpoint_file, map_location="cpu")
@@ -148,15 +157,19 @@ def load_model_weights(model, checkpoint_file, sd_model_hash):

    model.sd_model_hash = sd_model_hash
    model.sd_model_checkpoint = checkpoint_file
    model.sd_checkpoint_info = checkpoint_info


def load_model():
    from modules import lowvram, sd_hijack
    checkpoint_info = select_checkpoint()

    sd_config = OmegaConf.load(shared.cmd_opts.config)
    if checkpoint_info.config != shared.cmd_opts.config:
        print(f"Loading config from: {shared.cmd_opts.config}")

    sd_config = OmegaConf.load(checkpoint_info.config)
    sd_model = instantiate_from_config(sd_config.model)
    load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
    load_model_weights(sd_model, checkpoint_info)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.setup_for_low_vram(sd_model, shared.cmd_opts.medvram)
@@ -178,6 +191,9 @@ def reload_model_weights(sd_model, info=None):
    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        return

    if sd_model.sd_checkpoint_info.config != checkpoint_info.config:
        return load_model()

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
        lowvram.send_everything_to_cpu()
    else:
@@ -185,7 +201,7 @@ def reload_model_weights(sd_model, info=None):

    sd_hijack.model_hijack.undo_hijack(sd_model)

    load_model_weights(sd_model, checkpoint_info.filename, checkpoint_info.hash)
    load_model_weights(sd_model, checkpoint_info)

    sd_hijack.model_hijack.hijack(sd_model)