Commit 4cbbb881 authored by Φφ's avatar Φφ
Browse files

Unload checkpoints on Request

…to free VRAM.

Adds new action buttons to the settings to manually free and reload checkpoints, essentially
juggling models between RAM and VRAM.
parent a9fed7c3
Loading
Loading
Loading
Loading
+13 −1
Original line number Original line Diff line number Diff line
@@ -18,7 +18,7 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin,Image
from PIL import PngImagePlugin,Image
from modules.sd_models import checkpoints_list
from modules.sd_models import checkpoints_list, unload_model_weights, reload_model_weights
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
from modules import devices
@@ -150,6 +150,8 @@ class Api:
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/embedding", self.train_embedding, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/train/hypernetwork", self.train_hypernetwork, methods=["POST"], response_model=TrainResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
        self.add_api_route("/sdapi/v1/memory", self.get_memory, methods=["GET"], response_model=MemoryResponse)
        self.add_api_route("/sdapi/v1/unload-checkpoint", self.unloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/reload-checkpoint", self.reloadapi, methods=["POST"])
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)
        self.add_api_route("/sdapi/v1/scripts", self.get_scripts_list, methods=["GET"], response_model=ScriptsList)


    def add_api_route(self, path: str, endpoint, **kwargs):
    def add_api_route(self, path: str, endpoint, **kwargs):
@@ -412,6 +414,16 @@ class Api:


        return {}
        return {}


    def unloadapi(self):
        """Unload the current checkpoint weights via ``unload_model_weights``.

        Returns:
            An empty dict, which is sent back as the response body.
        """
        unload_model_weights()
        return {}

    def reloadapi(self):
        """Reload the checkpoint weights via ``reload_model_weights``.

        Returns:
            An empty dict, which is sent back as the response body.
        """
        reload_model_weights()
        return {}

    def skip(self):
    def skip(self):
        shared.state.skip()
        shared.state.skip()


+21 −1
Original line number Original line Diff line number Diff line
@@ -494,7 +494,7 @@ def reload_model_weights(sd_model=None, info=None):
    if sd_model is None or checkpoint_config != sd_model.used_config:
    if sd_model is None or checkpoint_config != sd_model.used_config:
        del sd_model
        del sd_model
        checkpoints_loaded.clear()
        checkpoints_loaded.clear()
        load_model(checkpoint_info, already_loaded_state_dict=state_dict, time_taken_to_load_state_dict=timer.records["load weights from disk"])
        load_model(checkpoint_info, already_loaded_state_dict=state_dict)
        return shared.sd_model
        return shared.sd_model


    try:
    try:
@@ -517,3 +517,23 @@ def reload_model_weights(sd_model=None, info=None):
    print(f"Weights loaded in {timer.summary()}.")
    print(f"Weights loaded in {timer.summary()}.")


    return sd_model
    return sd_model

def unload_model_weights(sd_model=None, info=None):
    """Move the active checkpoint to the CPU device and drop the shared reference to it.

    When ``shared.sd_model`` is set, the model is moved to ``devices.cpu``,
    the hijack installed through ``sd_hijack.model_hijack`` is undone,
    ``shared.sd_model`` is cleared, and a garbage-collection pass plus the
    torch cache cleanup helpers are run.

    Args:
        sd_model: Not used to pick the model to unload (the function always
            operates on ``shared.sd_model``). It is only echoed back when no
            model is currently loaded.
        info: Unused; presumably present to mirror the signature of
            ``reload_model_weights``.

    Returns:
        ``None`` when a model was unloaded, otherwise the ``sd_model``
        argument exactly as it was passed in.
    """
    from modules import devices, sd_hijack

    timer = Timer()

    # Explicit None check: "is a model loaded" should not depend on the
    # truthiness of the model object itself.
    if shared.sd_model is not None:
        shared.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(shared.sd_model)
        shared.sd_model = None
        sd_model = None
        # Drop Python-side references first, then release cached device memory.
        gc.collect()
        devices.torch_gc()
        torch.cuda.empty_cache()

    # NOTE(review): nothing is recorded on the timer before this point —
    # confirm that summary() reports a meaningful duration here.
    print(f"Unloaded weights {timer.summary()}.")

    return sd_model
 No newline at end of file
+22 −0
Original line number Original line Diff line number Diff line
@@ -1491,12 +1491,34 @@ def create_ui():
                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
                request_notifications = gr.Button(value='Request browser notifications', elem_id="request_notifications")
                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
                reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
                with gr.Row():
                    unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
                    reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")


            with gr.TabItem("Licenses"):
            with gr.TabItem("Licenses"):
                gr.HTML(shared.html("licenses.html"), elem_id="licenses")
                gr.HTML(shared.html("licenses.html"), elem_id="licenses")


            gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
            gr.Button(value="Show all pages", elem_id="settings_show_all_pages")
            
            

        def unload_sd_weights():
            """Button callback: unload the SD checkpoint weights, discarding the result."""
            modules.sd_models.unload_model_weights()

        def reload_sd_weights():
            """Button callback: reload the SD checkpoint weights, discarding the result."""
            modules.sd_models.reload_model_weights()

        # No inputs are read and no components are updated; the click only runs the callback.
        unload_sd_model.click(fn=unload_sd_weights, inputs=[], outputs=[])

        # No inputs are read and no components are updated; the click only runs the callback.
        reload_sd_model.click(fn=reload_sd_weights, inputs=[], outputs=[])

        request_notifications.click(
        request_notifications.click(
            fn=lambda: None,
            fn=lambda: None,
            inputs=[],
            inputs=[],