Commit b1717c0a authored by AUTOMATIC's avatar AUTOMATIC
Browse files

do not load wait for shared.sd_model to load at startup

parent 696c338e
Loading
Loading
Loading
Loading
+40 −14
Original line number Diff line number Diff line
@@ -2,6 +2,8 @@ import collections
import os.path
import sys
import gc
import threading

import torch
import re
import safetensors.torch
@@ -404,13 +406,39 @@ def repair_config(sd_config):
sd1_clip_weight = 'cond_stage_model.transformer.text_model.embeddings.token_embedding.weight'
sd2_clip_weight = 'cond_stage_model.model.transformer.resblocks.0.attn.in_proj_weight'

def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_to_load_state_dict=None):

class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.lock = threading.Lock()

    def get_sd_model(self):
        if self.sd_model is None:
            with self.lock:
                try:
                    load_model()
                except Exception as e:
                    errors.display(e, "loading stable diffusion model")
                    print("", file=sys.stderr)
                    print("Stable diffusion model failed to load", file=sys.stderr)
                    self.sd_model = None

        return self.sd_model

    def set_sd_model(self, v):
        self.sd_model = v


model_data = SdModelData()


def load_model(checkpoint_info=None, already_loaded_state_dict=None):
    from modules import lowvram, sd_hijack
    checkpoint_info = checkpoint_info or select_checkpoint()

    if shared.sd_model:
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        shared.sd_model = None
    if model_data.sd_model:
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        gc.collect()
        devices.torch_gc()

@@ -464,7 +492,7 @@ def load_model(checkpoint_info=None, already_loaded_state_dict=None, time_taken_
    timer.record("hijack")

    sd_model.eval()
    shared.sd_model = sd_model
    model_data.sd_model = sd_model

    sd_hijack.model_hijack.embedding_db.load_textual_inversion_embeddings(force_reload=True)  # Reload embeddings after model load as they may or may not fit the model

@@ -484,7 +512,7 @@ def reload_model_weights(sd_model=None, info=None):
    checkpoint_info = info or select_checkpoint()

    if not sd_model:
        sd_model = shared.sd_model
        sd_model = model_data.sd_model

    if sd_model is None:  # previous model load failed
        current_checkpoint_info = None
@@ -512,7 +540,7 @@ def reload_model_weights(sd_model=None, info=None):
        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return shared.sd_model
        return model_data.sd_model

    try:
        load_model_weights(sd_model, checkpoint_info, state_dict, timer)
@@ -535,17 +563,15 @@ def reload_model_weights(sd_model=None, info=None):

    return sd_model


def unload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    timer = Timer()

    if shared.sd_model:

        # shared.sd_model.cond_stage_model.to(devices.cpu)
        # shared.sd_model.first_stage_model.to(devices.cpu)
        shared.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        shared.sd_model = None
    if model_data.sd_model:
        model_data.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        sd_model = None
        gc.collect()
        devices.torch_gc()
+27 −4
Original line number Diff line number Diff line
@@ -16,6 +16,7 @@ import modules.styles
import modules.devices as devices
from modules import localization, script_loading, errors, ui_components, shared_items, cmd_args
from modules.paths_internal import models_path, script_path, data_path, sd_configs_path, sd_default_config, sd_model_file, default_sd_model_file, extensions_dir, extensions_builtin_dir
from ldm.models.diffusion.ddpm import LatentDiffusion

demo = None

@@ -600,13 +601,37 @@ class Options:
        return value



opts = Options()
if os.path.exists(config_filename):
    opts.load(config_filename)


class Shared(sys.modules[__name__].__class__):
    """
    this class is here to provide sd_model field as a property, so that it can be created and loaded on demand rather than
    at program startup.
    """

    sd_model_val = None

    @property
    def sd_model(self):
        import modules.sd_models

        return modules.sd_models.model_data.get_sd_model()

    @sd_model.setter
    def sd_model(self, value):
        import modules.sd_models

        modules.sd_models.model_data.set_sd_model(value)


sd_model: LatentDiffusion = None  # this var is here just for IDE's type checking; it cannot be accessed because the class field above will be accessed instead
sys.modules[__name__].__class__ = Shared

settings_components = None
"""assinged from ui.py, a mapping on setting anmes to gradio components repsponsible for those settings"""
"""assinged from ui.py, a mapping on setting names to gradio components repsponsible for those settings"""

latent_upscale_default_mode = "Latent"
latent_upscale_modes = {
@@ -620,8 +645,6 @@ latent_upscale_modes = {

sd_upscalers = []

sd_model = None

clip_model = None

progress_print_out = sys.stdout
+4 −6
Original line number Diff line number Diff line
@@ -828,7 +828,7 @@ def create_ui():
                        with FormGroup():
                            with FormRow():
                                cfg_scale = gr.Slider(minimum=1.0, maximum=30.0, step=0.5, label='CFG Scale', value=7.0, elem_id="img2img_cfg_scale")
                                image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
                                image_cfg_scale = gr.Slider(minimum=0, maximum=3.0, step=0.05, label='Image CFG Scale', value=1.5, elem_id="img2img_image_cfg_scale", visible=False)
                            denoising_strength = gr.Slider(minimum=0.0, maximum=1.0, step=0.01, label='Denoising strength', value=0.75, elem_id="img2img_denoising_strength")

                    elif category == "seed":
@@ -1693,11 +1693,9 @@ def create_ui():
                show_progress=info.refresh is not None,
            )

        text_settings.change(
            fn=lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit"),
            inputs=[],
            outputs=[image_cfg_scale],
        )
        update_image_cfg_scale_visibility = lambda: gr.update(visible=shared.sd_model and shared.sd_model.cond_stage_key == "edit")
        text_settings.change(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])
        demo.load(fn=update_image_cfg_scale_visibility, inputs=[], outputs=[image_cfg_scale])

        button_set_checkpoint = gr.Button('Change checkpoint', elem_id='change_checkpoint', visible=False)
        button_set_checkpoint.click(
+5 −11
Original line number Diff line number Diff line
@@ -6,6 +6,8 @@ import signal
import re
import warnings
import json
from threading import Thread

from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.middleware.gzip import GZipMiddleware
@@ -191,18 +193,10 @@ def initialize():
    modules.textual_inversion.textual_inversion.list_textual_inversion_templates()
    startup_timer.record("refresh textual inversion templates")

    try:
        modules.sd_models.load_model()
    except Exception as e:
        errors.display(e, "loading stable diffusion model")
        print("", file=sys.stderr)
        print("Stable diffusion model failed to load, exiting", file=sys.stderr)
        exit(1)
    startup_timer.record("load SD checkpoint")

    shared.opts.data["sd_model_checkpoint"] = shared.sd_model.sd_checkpoint_info.title
    # load model in parallel to other startup stuff
    Thread(target=lambda: shared.sd_model).start()

    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()))
    shared.opts.onchange("sd_model_checkpoint", wrap_queued_call(lambda: modules.sd_models.reload_model_weights()), call=False)
    shared.opts.onchange("sd_vae", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("sd_vae_as_default", wrap_queued_call(lambda: modules.sd_vae.reload_vae_weights()), call=False)
    shared.opts.onchange("temp_dir", ui_tempdir.on_tmpdir_changed)