Commit 4cbbb881 authored by Φφ's avatar Φφ
Browse files

Unload checkpoints on Request

…to free VRAM.

New Action buttons in the settings to manually free and reload checkpoints, essentially
juggling models between RAM and VRAM.
parent a9fed7c3
Loading
Loading
Loading
Loading
+13 −1
Original line number Diff line number Diff line
@@ -18,7 +18,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -150,6 +150,8 @@ class Api:
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
        self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)

    def add_api_route(self, path: str, endpoint, **kwargs):
@@ -412,6 +414,16 @@ class Api:

        return {}

    def unloadapi(self):
        unload_model_weights()

        return {}

    def reloadapi(self):
        reload_model_weights()

        return {}

    def skip(self):
        shared.state.skip()

+21 −1
Original line number Diff line number Diff line
@@ -494,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
    if sd_model is None or checkpoint_config != sd_model.used_config:
        del sd_model
        checkpoints_loaded.clear()
        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return shared.sd_model

    try:
@@ -517,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
    print(f"Weights loaded in {timer.summary()}.")

    return sd_model

def unload_model_weights(sd_model=None, info=None):
    from modules import lowvram, devices, sd_hijack
    timer = Timer()

    if shared.sd_model:

        # shared.sd_model.cond_stage_model.to(devices.cpu)
        # shared.sd_model.first_stage_model.to(devices.cpu)
        shared.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        shared.sd_model = None
        sd_model = None
        gc.collect()
        devices.torch_gc()
        torch.cuda.empty_cache()

    print(f"Unloaded weights {timer.summary()}.")

    return sd_model
 No newline at end of file
+22 −0
Original line number Diff line number Diff line
@@ -1491,12 +1491,34 @@ def create_ui():
                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
                with gr.Row():
                    unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
                    reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")

            with gr.TabItem("Licenses"):
                gr.HTML(shared.html("licenses.html"), elem_id="licenses")

            gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
            

        def unload_sd_weights():
            modules.sd_models.unload_model_weights()

        def reload_sd_weights():
            modules.sd_models.reload_model_weights()

        unload_sd_model.click(
            fn=unload_sd_weights,
            inputs=[],
            outputs=[]
        )

        reload_sd_model.click(
            fn=reload_sd_weights,
            inputs=[],
            outputs=[]
        )

        request_notifications.click(
            fn=lambda: None,
            inputs=[],