Commit 282903bb authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

repair unload sd checkpoint button

parent 0d65d0ea
Loading
Loading
Loading
Loading
+5 −6
Original line number Diff line number Diff line
@@ -17,7 +17,7 @@ from fastapi.encoders import jsonable_encoder
from secrets import compare_digest

import modules.shared as shared
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste
from modules import sd_samplers, deepbooru, sd_hijack, images, scripts, ui, postprocessing, errors, restart, shared_items, script_callbacks, generation_parameters_copypaste, sd_models
from modules.api import models
from modules.shared import opts
from modules.processing import StableDiffusionProcessingTxt2Img, StableDiffusionProcessingImg2Img, process_images
@@ -25,7 +25,6 @@ from modules.textual_inversion.textual_inversion import create_embedding, train_
from modules.textual_inversion.preprocess import preprocess
from modules.hypernetworks.hypernetwork import create_hypernetwork, train_hypernetwork
from PIL import PngImagePlugin, Image
from modules.sd_models import unload_model_weights, reload_model_weights, checkpoint_aliases
from modules.sd_models_config import find_checkpoint_config_near_filename
from modules.realesrgan_model import get_realesrgan_models
from modules import devices
@@ -541,12 +540,12 @@ class Api:
        return {}

    def unloadapi(self):
        unload_model_weights()
        sd_models.unload_model_weights()

        return {}

    def reloadapi(self):
        reload_model_weights()
        sd_models.send_model_to_device(shared.sd_model)

        return {}

@@ -566,7 +565,7 @@ class Api:

    def set_config(self, req: dict[str, Any]):
        checkpoint_name = req.get("sd_model_checkpoint", None)
        if checkpoint_name is not None and checkpoint_name not in checkpoint_aliases:
        if checkpoint_name is not None and checkpoint_name not in sd_models.checkpoint_aliases:
            raise RuntimeError(f"model {checkpoint_name!r} not found")

        for k, v in req.items():
+1 −12
Original line number Diff line number Diff line
import collections
import os.path
import sys
import gc
import threading

import torch
@@ -798,17 +797,7 @@ def reload_model_weights(sd_model=None, info=None):


def unload_model_weights(sd_model=None, info=None):
    timer = Timer()

    if model_data.sd_model:
        model_data.sd_model.to(devices.cpu)
        sd_hijack.model_hijack.undo_hijack(model_data.sd_model)
        model_data.sd_model = None
        sd_model = None
        gc.collect()
        devices.torch_gc()

    print(f"Unloaded weights {timer.summary()}.")
    send_model_to_cpu(sd_model or shared.sd_model)

    return sd_model

+17 −7
Original line number Diff line number Diff line
import gradio as gr

from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo
from modules import ui_common, shared, script_callbacks, scripts, sd_models, sysinfo, timer
from modules.call_queue import wrap_gradio_call
from modules.shared import opts
from modules.ui_components import FormRow
@@ -177,8 +177,8 @@ class UiSettings:
                    download_localization = gr.Button(value='Download localization template', elem_id="download_localization")
                    reload_script_bodies = gr.Button(value='Reload custom script bodies (No ui updates, No restart)', variant='secondary', elem_id="settings_reload_script_bodies")
                    with gr.Row():
                        unload_sd_model = gr.Button(value='Unload SD checkpoint to free VRAM', elem_id="sett_unload_sd_model")
                        reload_sd_model = gr.Button(value='Reload the last SD checkpoint back into VRAM', elem_id="sett_reload_sd_model")
                        unload_sd_model = gr.Button(value='Unload SD checkpoint to RAM', elem_id="sett_unload_sd_model")
                        reload_sd_model = gr.Button(value='Load SD checkpoint to VRAM from RAM', elem_id="sett_reload_sd_model")
                    with gr.Row():
                        calculate_all_checkpoint_hash = gr.Button(value='Calculate hash for all checkpoint', elem_id="calculate_all_checkpoint_hash")
                        calculate_all_checkpoint_hash_threads = gr.Number(value=1, label="Number of parallel calculations", elem_id="calculate_all_checkpoint_hash_threads", precision=0, minimum=1)
@@ -194,16 +194,26 @@ class UiSettings:

                self.text_settings = gr.Textbox(elem_id="settings_json", value=lambda: opts.dumpjson(), visible=False)

            def call_func_and_return_text(func, text):
                def handler():
                    t = timer.Timer()
                    func()
                    t.record(text)

                    return f'{text} in {t.total:.1f}s'

                return handler

            unload_sd_model.click(
                fn=sd_models.unload_model_weights,
                fn=call_func_and_return_text(sd_models.unload_model_weights, 'Unloaded the checkpoint'),
                inputs=[],
                outputs=[]
                outputs=[self.result]
            )

            reload_sd_model.click(
                fn=sd_models.reload_model_weights,
                fn=call_func_and_return_text(lambda: sd_models.send_model_to_device(shared.sd_model), 'Loaded the checkpoint'),
                inputs=[],
                outputs=[]
                outputs=[self.result]
            )

            request_notifications.click(