Commit e14b586d authored by Sakura-Luna's avatar Sakura-Luna
Browse files

Add Tiny AE live preview

parent b08500ce
Loading
Loading
Loading
Loading
+13 −8
Original line number Original line Diff line number Diff line
@@ -2,7 +2,7 @@ from collections import namedtuple
import numpy as np
import numpy as np
import torch
import torch
from PIL import Image
from PIL import Image
from modules import devices, processing, images, sd_vae_approx
from modules import devices, processing, images, sd_vae_approx, sd_vae_taesd


from modules.shared import opts, state
from modules.shared import opts, state
import modules.shared as shared
import modules.shared as shared
@@ -22,21 +22,26 @@ def setup_img2img_steps(p, steps=None):
    return steps, t_enc
    return steps, t_enc




approximation_indexes = {"Full": 0, "Approx NN": 1, "Approx cheap": 2}
approximation_indexes = {"Full": 0, "Tiny AE": 1, "Approx NN": 2, "Approx cheap": 3}




def single_sample_to_image(sample, approximation=None):
def single_sample_to_image(sample, approximation=None):
    if approximation is None:
    if approximation is None:
        approximation = approximation_indexes.get(opts.show_progress_type, 0)
        approximation = approximation_indexes.get(opts.show_progress_type, 0)


    if approximation == 2:
    if approximation == 1:
        x_sample = sd_vae_taesd.decode()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
        x_sample = sd_vae_taesd.TAESD.unscale_latents(x_sample)
        x_sample = torch.clamp((x_sample * 0.25) + 0.5, 0, 1)
    else:
        if approximation == 3:
            x_sample = sd_vae_approx.cheap_approximation(sample)
            x_sample = sd_vae_approx.cheap_approximation(sample)
    elif approximation == 1:
        elif approximation == 2:
            x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
            x_sample = sd_vae_approx.model()(sample.to(devices.device, devices.dtype).unsqueeze(0))[0].detach()
        else:
        else:
            x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]
            x_sample = processing.decode_first_stage(shared.sd_model, sample.unsqueeze(0))[0]

        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)
        x_sample = torch.clamp((x_sample + 1.0) / 2.0, min=0.0, max=1.0)

    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
    x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)
    x_sample = x_sample.astype(np.uint8)
    x_sample = x_sample.astype(np.uint8)
    return Image.fromarray(x_sample)
    return Image.fromarray(x_sample)
+76 −0
Original line number Original line Diff line number Diff line
"""
Tiny AutoEncoder for Stable Diffusion
(DNN for encoding / decoding SD's latent space)

https://github.com/madebyollin/taesd
"""
import os
import torch
import torch.nn as nn

from modules import devices, paths_internal

sd_vae_taesd = None


def conv(n_in, n_out, **kwargs):
    return nn.Conv2d(n_in, n_out, 3, padding=1, **kwargs)


class Clamp(nn.Module):
    @staticmethod
    def forward(x):
        return torch.tanh(x / 3) * 3


class Block(nn.Module):
    def __init__(self, n_in, n_out):
        super().__init__()
        self.conv = nn.Sequential(conv(n_in, n_out), nn.ReLU(), conv(n_out, n_out), nn.ReLU(), conv(n_out, n_out))
        self.skip = nn.Conv2d(n_in, n_out, 1, bias=False) if n_in != n_out else nn.Identity()
        self.fuse = nn.ReLU()

    def forward(self, x):
        return self.fuse(self.conv(x) + self.skip(x))


def decoder():
    return nn.Sequential(
        Clamp(), conv(4, 64), nn.ReLU(),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), Block(64, 64), Block(64, 64), nn.Upsample(scale_factor=2), conv(64, 64, bias=False),
        Block(64, 64), conv(64, 3),
    )


class TAESD(nn.Module):
    latent_magnitude = 2
    latent_shift = 0.5

    def __init__(self, decoder_path="taesd_decoder.pth"):
        """Initialize pretrained TAESD on the given device from the given checkpoints."""
        super().__init__()
        self.decoder = decoder()
        self.decoder.load_state_dict(
            torch.load(decoder_path, map_location='cpu' if devices.device.type != 'cuda' else None))

    @staticmethod
    def unscale_latents(x):
        """[0, 1] -> raw latents"""
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)


def decode():
    global sd_vae_taesd

    if sd_vae_taesd is None:
        model_path = os.path.join(paths_internal.models_path, "VAE-approx", "taesd_decoder.pth")
        if os.path.exists(model_path):
            sd_vae_taesd = TAESD(model_path)
            sd_vae_taesd.eval()
            sd_vae_taesd.to(devices.device, devices.dtype)
        else:
            raise FileNotFoundError('Tiny AE mdoel not found')

    return sd_vae_taesd.decoder
+1 −1
Original line number Original line Diff line number Diff line
@@ -425,7 +425,7 @@ options_templates.update(options_section(('ui', "Live previews"), {
    "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
    "live_previews_enable": OptionInfo(True, "Show live previews of the created image"),
    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
    "show_progress_grid": OptionInfo(True, "Show previews of all images generated in a batch as a grid"),
    "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
    "show_progress_every_n_steps": OptionInfo(10, "Show new live preview image every N sampling steps. Set to -1 to show after completion of batch.", gr.Slider, {"minimum": -1, "maximum": 32, "step": 1}),
    "show_progress_type": OptionInfo("Approx NN", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Approx NN", "Approx cheap"]}),
    "show_progress_type": OptionInfo("Tiny AE", "Image creation progress preview mode", gr.Radio, {"choices": ["Full", "Tiny AE", "Approx NN", "Approx cheap"]}),
    "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
    "live_preview_content": OptionInfo("Prompt", "Live preview subject", gr.Radio, {"choices": ["Combined", "Prompt", "Negative prompt"]}),
    "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
    "live_preview_refresh_period": OptionInfo(1000, "Progressbar/preview update period, in milliseconds")
}))
}))
+11 −0
Original line number Original line Diff line number Diff line
@@ -144,10 +144,21 @@ Use --skip-version-check commandline argument to disable this check.
            """.strip())
            """.strip())




def check_taesd():
    from modules.paths_internal import models_path

    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'
    model_path = os.path.join(models_path, "VAE-approx", "taesd_decoder.pth")
    if not os.path.exists(model_path):
        print('download taesd model')
        torch.hub.download_url_to_file(model_url, os.path.dirname(model_path))


def initialize():
def initialize():
    fix_asyncio_event_loop_policy()
    fix_asyncio_event_loop_policy()


    check_versions()
    check_versions()
    check_taesd()


    extensions.list_extensions()
    extensions.list_extensions()
    localization.list_localizations(cmd_opts.localizations_dir)
    localization.list_localizations(cmd_opts.localizations_dir)