Commit 64fd9163 authored by Kohaku-Blueleaf's avatar Kohaku-Blueleaf
Browse files

Add consistency decoder

parent 9c1c0da0
Loading
Loading
Loading
Loading
+9 −2
Original line number Diff line number Diff line
@@ -3,7 +3,7 @@ from collections import namedtuple
import numpy as np
import torch
from PIL import Image
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, shared, sd_models
from modules import devices, images, sd_vae_approx, sd_samplers, sd_vae_taesd, sd_vae_consistency, shared, sd_models
from modules.shared import opts, state
import k_diffusion.sampling

@@ -31,7 +31,7 @@ def setup_img2img_steps(p, steps=None):
    return steps, t_enc


approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3}
approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2, "TAESD": 3, "Consistency Decoder": 4}


def samples_to_images_tensor(sample, approximation=None, model=None):
@@ -51,6 +51,13 @@ def samples_to_images_tensor(sample, approximation=None, model=None):
    elif approximation == 3:
        x_sample = sd_vae_taesd.decoder_model()(sample.to(devices.device, devices.dtype)).detach()
        x_sample = x_sample * 2 - 1
    elif approximation == 4:
        with devices.autocast(), torch.no_grad():
            x_sample = sd_vae_consistency.decoder_model()(
                sample.to(devices.device, devices.dtype)/0.18215,
                schedule=[1.0]
            )
        sd_vae_consistency.unload()
    else:
        if model is None:
            model = shared.sd_model
+35 −0
Original line number Diff line number Diff line
"""
Consistency Decoder
Improved decoding for stable diffusion vaes.

https://github.com/openai/consistencydecoder
"""
import os
import torch
import torch.nn as nn

from modules import devices, paths_internal, shared
from consistencydecoder import ConsistencyDecoder


sd_vae_consistency_models = None
model_path = os.path.join(paths_internal.models_path, 'consistencydecoder')


def decoder_model():
    global sd_vae_consistency_models
    if getattr(shared.sd_model, 'is_sdxl', False):
        raise NotImplementedError("SDXL is not supported for consistency decoder")
    if sd_vae_consistency_models is not None:
        sd_vae_consistency_models.ckpt.to(devices.device)
        return sd_vae_consistency_models

    loaded_model = ConsistencyDecoder(devices.device, model_path)
    sd_vae_consistency_models = loaded_model
    return loaded_model


def unload():
    global sd_vae_consistency_models
    if sd_vae_consistency_models is not None:
        sd_vae_consistency_models.ckpt.to('cpu')
 No newline at end of file
+1 −1
Original line number Diff line number Diff line
@@ -172,7 +172,7 @@ For img2img, VAE is used to process user's input image before the sampling, and
    "sd_vae_overrides_per_model_preferences": OptionInfo(True, "Selected VAE overrides per-model preferences").info("you can set per-model VAE either by editing user metadata for checkpoints, or by making the VAE have same name as checkpoint"),
    "auto_vae_precision": OptionInfo(True, "Automatically revert VAE to 32-bit floats").info("triggers when a tensor with NaNs is produced in VAE; disabling the option in this case will result in a black square image"),
    "sd_vae_encode_method": OptionInfo("Full", "VAE type for encode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Encoder').info("method to encode image to latent (use in img2img, hires-fix or inpaint mask)"),
    "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD"]}, infotext='VAE Decoder').info("method to decode latent to image"),
    "sd_vae_decode_method": OptionInfo("Full", "VAE type for decode", gr.Radio, {"choices": ["Full", "TAESD", "Consistency Decoder"]}, infotext='VAE Decoder').info("method to decode latent to image"),
}))

options_templates.update(options_section(('img2img', "img2img"), {
+2 −0
Original line number Diff line number Diff line
@@ -32,3 +32,5 @@ torch
torchdiffeq
torchsde
transformers==4.30.2

git+https://github.com/openai/consistencydecoder.git
+1 −0
Original line number Diff line number Diff line
@@ -30,3 +30,4 @@ torchdiffeq==0.2.3
torchsde==0.2.6
transformers==4.30.2
httpx==0.24.1
git+https://github.com/openai/consistencydecoder.git