Commit b8159d09 authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

add XL support for live previews: approx and TAESD

parent 6f23da60
Loading
Loading
Loading
Loading
+1 −1
Original line number Diff line number Diff line
@@ -48,7 +48,7 @@ def extend_sdxl(model):
    discretization = sgm.modules.diffusionmodules.discretizer.LegacyDDPMDiscretization()
    model.alphas_cumprod = torch.asarray(discretization.alphas_cumprod, device=devices.device, dtype=dtype)

    model.is_xl = True
    model.is_sdxl = True


sgm.models.diffusion.DiffusionEngine.get_learned_conditioning = get_learned_conditioning
+26 −11
Original line number Diff line number Diff line
@@ -2,9 +2,9 @@ import os

import torch
from torch import nn
from modules import devices, paths
from modules import devices, paths, shared

sd_vae_approx_model = None
sd_vae_approx_models = {}


class VAEApprox(nn.Module):
@@ -31,19 +31,34 @@ class VAEApprox(nn.Module):
        return x


def download_model(model_path, model_url):
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

        print(f'Downloading VAEApprox model to: {model_path}')
        torch.hub.download_url_to_file(model_url, model_path)


def model():
    global sd_vae_approx_model
    model_name = "vaeapprox-sdxl.pt" if getattr(shared.sd_model, 'is_sdxl', False) else "model.pt"
    loaded_model = sd_vae_approx_models.get(model_name)

    if sd_vae_approx_model is None:
        model_path = os.path.join(paths.models_path, "VAE-approx", "model.pt")
        sd_vae_approx_model = VAEApprox()
    if loaded_model is None:
        model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
        if not os.path.exists(model_path):
            model_path = os.path.join(paths.script_path, "models", "VAE-approx", "model.pt")
        sd_vae_approx_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
        sd_vae_approx_model.eval()
        sd_vae_approx_model.to(devices.device, devices.dtype)
            model_path = os.path.join(paths.script_path, "models", "VAE-approx", model_name)

        if not os.path.exists(model_path):
            model_path = os.path.join(paths.models_path, "VAE-approx", model_name)
            download_model(model_path, 'https://github.com/AUTOMATIC1111/stable-diffusion-webui/releases/download/v1.0.0-pre/' + model_name)

        loaded_model = VAEApprox()
        loaded_model.load_state_dict(torch.load(model_path, map_location='cpu' if devices.device.type != 'cuda' else None))
        loaded_model.eval()
        loaded_model.to(devices.device, devices.dtype)
        sd_vae_approx_models[model_name] = loaded_model

    return sd_vae_approx_model
    return loaded_model


def cheap_approximation(sample):
+13 −13
Original line number Diff line number Diff line
@@ -8,9 +8,9 @@ import os
import torch
import torch.nn as nn

from modules import devices, paths_internal
from modules import devices, paths_internal, shared

sd_vae_taesd = None
sd_vae_taesd_models = {}


def conv(n_in, n_out, **kwargs):
@@ -61,9 +61,7 @@ class TAESD(nn.Module):
        return x.sub(TAESD.latent_shift).mul(2 * TAESD.latent_magnitude)


def download_model(model_path):
    model_url = 'https://github.com/madebyollin/taesd/raw/main/taesd_decoder.pth'

def download_model(model_path, model_url):
    if not os.path.exists(model_path):
        os.makedirs(os.path.dirname(model_path), exist_ok=True)

@@ -72,17 +70,19 @@ def download_model(model_path):


def model():
    global sd_vae_taesd
    model_name = "taesdxl_decoder.pth" if getattr(shared.sd_model, 'is_sdxl', False) else "taesd_decoder.pth"
    loaded_model = sd_vae_taesd_models.get(model_name)

    if sd_vae_taesd is None:
        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", "taesd_decoder.pth")
        download_model(model_path)
    if loaded_model is None:
        model_path = os.path.join(paths_internal.models_path, "VAE-taesd", model_name)
        download_model(model_path, 'https://github.com/madebyollin/taesd/raw/main/' + model_name)

        if os.path.exists(model_path):
            sd_vae_taesd = TAESD(model_path)
            sd_vae_taesd.eval()
            sd_vae_taesd.to(devices.device, devices.dtype)
            loaded_model = TAESD(model_path)
            loaded_model.eval()
            loaded_model.to(devices.device, devices.dtype)
            sd_vae_taesd_models[model_name] = loaded_model
        else:
            raise FileNotFoundError('TAESD model not found')

    return sd_vae_taesd.decoder
    return loaded_model.decoder