Unverified Commit 685f9631 authored by AUTOMATIC1111, committed by GitHub

Merge pull request #5586 from wywywywy/ldsr-improvements

LDSR improvements - cache / optimization / opt_channelslast
parents 0a81dd52 1581d5a1
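
The pull request covers three changes to the LDSR upscaler: the loaded model can be kept in memory between upscales behind a new ldsr_cached option, the model and its inputs are moved to shared.device instead of a hard-coded CUDA device, and the opt_channelslast command-line option is honoured when loading. As a rough, standalone sketch of the caching pattern the diff uses (the function and variable names here are invented for illustration and are not part of the webui code):

import torch

_cached_model = None  # module-level cache, analogous to cached_ldsr_model below

def load_model(build_fn, use_cache):
    # Return the cached model when caching is enabled, otherwise build a fresh one.
    global _cached_model
    if use_cache and _cached_model is not None:
        return _cached_model
    model = build_fn()          # expensive: read checkpoint, move to device, etc.
    model.eval()
    if use_cache:
        _cached_model = model   # keep it resident for the next call
    return model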
+34 −15
@@ -11,25 +11,41 @@ from omegaconf import OmegaConf

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack

warnings.filterwarnings("ignore", category=UserWarning)

cached_ldsr_model: torch.nn.Module = None


# Create LDSR Class
class LDSR:
    def load_model_from_config(self, half_attention):
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print(f"Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            pl_sd = torch.load(self.modelPath, map_location="cpu")
            sd = pl_sd["state_dict"]
            config = OmegaConf.load(self.yamlPath)
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model) # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}

    def __init__(self, model_path, yaml_path):
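
When the model is not served from the cache, the new loading path optionally converts the weights to fp16 (half_attention) and, when the webui was launched with opt_channelslast, to the channels_last memory format. A minimal sketch of what those two conversions look like on a toy convolutional module (not the LDSR model itself; all names here are illustrative):

import torch
import torch.nn as nn

# Toy stand-in for the diffusion model, for illustration only.
net = nn.Sequential(nn.Conv2d(3, 16, 3, padding=1), nn.ReLU(), nn.Conv2d(16, 3, 3, padding=1))

device = "cuda" if torch.cuda.is_available() else "cpu"
net = net.to(device)
if device == "cuda":
    net = net.half()                                # fp16 weights, as the half_attention branch does
net = net.to(memory_format=torch.channels_last)     # NHWC storage; convolution kernels can run faster in this layout

x = torch.randn(1, 3, 64, 64, device=device, dtype=net[0].weight.dtype)
x = x.to(memory_format=torch.channels_last)         # inputs should match the model's memory format
with torch.no_grad():
    out = net(x)
print(out.shape, out.is_contiguous(memory_format=torch.channels_last))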
@@ -94,6 +110,7 @@ class LDSR:
        down_sample_method = 'Lanczos'

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        im_og = image
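
The cache-clearing calls added in this hunk and the one below pair a Python garbage collection with a flush of PyTorch's CUDA allocator. A small helper in the same spirit, guarding the GPU flush behind an availability check (the helper name is ours, not part of the diff):

import gc
import torch

def free_memory():
    # Drop unreachable Python objects, then release cached GPU blocks if a CUDA device exists.
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()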
@@ -131,7 +148,9 @@ class LDSR:

        del model
        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        return a


@@ -146,7 +165,7 @@ def get_cond(selected_path):
    c = rearrange(c, '1 c h w -> 1 h w c')
    c = 2. * c - 1.

    c = c.to(shared.device)
    example["LR_image"] = c
    example["image"] = c_up

+1 −0
@@ -59,6 +59,7 @@ def on_ui_settings():
    import gradio as gr

    shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
    shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))


script_callbacks.on_ui_settings(on_ui_settings)