Commit 02d7abf5 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

helpful error message when trying to load 2.0 without config

failing to load model weights from settings won't break generation for currently loaded model anymore
parent 7e549468
Loading
Loading
Loading
Loading
+23 −2
Original line number Diff line number Diff line
@@ -2,9 +2,30 @@ import sys
import traceback


def print_error_explanation(message):
    lines = message.strip().split("\n")
    max_len = max([len(x) for x in lines])

    print('=' * max_len, file=sys.stderr)
    for line in lines:
        print(line, file=sys.stderr)
    print('=' * max_len, file=sys.stderr)


def display(e: Exception, task):
    print(f"{task or 'error'}: {type(e).__name__}", file=sys.stderr)
    print(traceback.format_exc(), file=sys.stderr)

    message = str(e)
    if "copying a param with shape torch.Size([640, 1024]) from checkpoint, the shape in current model is torch.Size([640, 768])" in message:
        print_error_explanation("""
The most likely cause of this is you are trying to load Stable Diffusion 2.0 model without specifying its connfig file.
See https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Features#stable-diffusion-20 for how to solve this.
        """)


def run(code, task):
    try:
        code()
    except Exception as e:
        print(f"{task}: {type(e).__name__}", file=sys.stderr)
        print(traceback.format_exc(), file=sys.stderr)
        display(task, e)
+18 −8
Original line number Diff line number Diff line
@@ -278,6 +278,7 @@ def enable_midas_autodownload():

    midas.api.load_model = load_model_wrapper


def load_model(checkpoint_info=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()
@@ -312,6 +313,7 @@ def load_model(checkpoint_info=None):
        sd_config.model.params.unet_config.params.use_fp16 = False

    sd_model = instantiate_from_config(sd_config.model)

    load_model_weights(sd_model, checkpoint_info)

    if shared.cmd_opts.lowvram or shared.cmd_opts.medvram:
@@ -340,6 +342,8 @@ def reload_model_weights(sd_model=None, info=None):
    if not sd_model:
        sd_model = shared.sd_model

    current_checkpoint_info = sd_model.sd_checkpoint_info

    if sd_model.sd_model_checkpoint == checkpoint_info.filename:
        return

@@ -356,8 +360,13 @@ def reload_model_weights(sd_model=None, info=None):

    sd_hijack.model_hijack.undo_hijack(sd_model)

    try:
        load_model_weights(sd_model, checkpoint_info)

    except Exception as e:
        print("Failed to load checkpoint, restoring previous")
        load_model_weights(sd_model, current_checkpoint_info)
        raise
    finally:
        sd_hijack.model_hijack.hijack(sd_model)
        script_callbacks.model_loaded_callback(sd_model)

@@ -365,4 +374,5 @@ def reload_model_weights(sd_model=None, info=None):
            sd_model.to(devices.device)

    print("Weights loaded.")

    return sd_model
+7 −2
Original line number Diff line number Diff line
@@ -14,7 +14,7 @@ import modules.interrogate
import modules.memmon
import modules.styles
import modules.devices as devices
from modules import localization, sd_vae, extensions, script_loading
from modules import localization, sd_vae, extensions, script_loading, errors
from modules.paths import models_path, script_path, sd_path


@@ -494,7 +494,12 @@ class Options:
            return False

        if self.data_labels[key].onchange is not None:
            try:
                self.data_labels[key].onchange()
            except Exception as e:
                errors.display(e, f"changing setting {key} to {value}")
                setattr(self, key, oldval)
                return False

        return True

+10 −2
Original line number Diff line number Diff line
@@ -9,7 +9,7 @@ from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware

from modules import import_hook
from modules import import_hook, errors
from modules.call_queue import wrap_queued_call, queue_lock, wrap_gradio_gpu_call
from modules.paths import script_path

@@ -61,7 +61,15 @@ def initialize():
    modelloader.load_upscalers()

    modules.sd_vae.refresh_vae_list()

    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        exit(1)

    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)