Commit 042e1d5d authored by Uminosachi's avatar Uminosachi
Browse files

Fix SD VAE switch error after model reuse

parent 9d2299ed
Loading
Loading
Loading
Loading
+20 −2
Original line number Diff line number Diff line
@@ -462,6 +462,7 @@ class SdModelData:
    def __init__(self):
        self.sd_model = None
        self.loaded_sd_models = []
        self.loaded_vae_states = {}
        self.was_loaded_at_least_once = False
        self.lock = threading.Lock()

@@ -485,16 +486,27 @@ class SdModelData:

        return self.sd_model

    def set_sd_model(self, v):
    def set_sd_model(self, v, already_loaded=False):
        self.sd_model = v
        if already_loaded:
            sd_vae_state = self.loaded_vae_states.get(v.sd_model_hash, {})
            sd_vae.base_vae = sd_vae_state.get("base_vae", None)
            sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None)
            sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None)

        try:
            self.loaded_sd_models.remove(v)
            self.loaded_vae_states.pop(v.sd_model_hash, {}).clear()
        except ValueError:
            pass

        if v is not None:
            self.loaded_sd_models.insert(0, v)
            self.loaded_vae_states[v.sd_model_hash] = dict(
                base_vae=sd_vae.base_vae,
                loaded_vae_file=sd_vae.loaded_vae_file,
                checkpoint_info=sd_vae.checkpoint_info,
            )


model_data = SdModelData()
@@ -649,6 +661,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
        if len(model_data.loaded_sd_models) > shared.opts.sd_checkpoints_limit > 0:
            print(f"Unloading model {len(model_data.loaded_sd_models)} over the limit of {shared.opts.sd_checkpoints_limit}: {loaded_model.sd_checkpoint_info.title}")
            model_data.loaded_sd_models.pop()
            model_data.loaded_vae_states.pop(loaded_model.sd_model_hash, {}).clear()
            send_model_to_trash(loaded_model)
            timer.record("send model to trash")

@@ -660,7 +673,7 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
        send_model_to_device(already_loaded)
        timer.record("send model to device")

        model_data.set_sd_model(already_loaded)
        model_data.set_sd_model(already_loaded, already_loaded=True)

        if not SkipWritingToConfig.skip:
            shared.opts.data["sd_model_checkpoint"] = already_loaded.sd_checkpoint_info.title
@@ -678,6 +691,11 @@ def reuse_model_from_already_loaded(sd_model, checkpoint_info, timer):
        sd_model = model_data.loaded_sd_models.pop()
        model_data.sd_model = sd_model

        sd_vae_state = model_data.loaded_vae_states.pop(sd_model.sd_model_hash, {})
        sd_vae.base_vae = sd_vae_state.get("base_vae", None)
        sd_vae.loaded_vae_file = sd_vae_state.get("loaded_vae_file", None)
        sd_vae.checkpoint_info = sd_vae_state.get("checkpoint_info", None)

        print(f"Reusing loaded model {sd_model.sd_checkpoint_info.title} to load {checkpoint_info.title}")
        return sd_model
    else: