Commit d6fdfde9 authored by unknown
Browse files

Merge branch 'master' of github.com:AUTOMATIC1111/stable-diffusion-webui

parents 4005cd66 685f9631
Loading
Loading
Loading
Loading
+34 −15
Original line number Diff line number Diff line
@@ -11,25 +11,41 @@ from omegaconf import OmegaConf

from ldm.models.diffusion.ddim import DDIMSampler
from ldm.util import instantiate_from_config, ismap
from modules import shared, sd_hijack

warnings.filterwarnings("ignore", category=UserWarning)

cached_ldsr_model: torch.nn.Module = None


# Create LDSR Class
class LDSR:
    def load_model_from_config(self, half_attention):
        """Load (or fetch from cache) the LDSR latent-diffusion model.

        Parameters:
            half_attention: if truthy, convert the model to half precision.

        Returns:
            dict with a single key "model" holding the ready torch module.

        Side effects: may populate the module-global `cached_ldsr_model`
        when the `ldsr_cached` option is enabled.
        """
        global cached_ldsr_model

        if shared.opts.ldsr_cached and cached_ldsr_model is not None:
            print("Loading model from cache")
            model: torch.nn.Module = cached_ldsr_model
        else:
            print(f"Loading model from {self.modelPath}")
            pl_sd = torch.load(self.modelPath, map_location="cpu")
            sd = pl_sd["state_dict"]
            config = OmegaConf.load(self.yamlPath)
            # Force the v1 DDPM class; the yaml may point at a newer target.
            config.model.target = "ldm.models.diffusion.ddpm.LatentDiffusionV1"
            model: torch.nn.Module = instantiate_from_config(config.model)
            # strict=False: checkpoint may lack keys for optional submodules.
            model.load_state_dict(sd, strict=False)
            model = model.to(shared.device)
            if half_attention:
                model = model.half()
            if shared.cmd_opts.opt_channelslast:
                model = model.to(memory_format=torch.channels_last)

            sd_hijack.model_hijack.hijack(model)  # apply optimization
            model.eval()

            if shared.opts.ldsr_cached:
                cached_ldsr_model = model

        return {"model": model}

    def __init__(self, model_path, yaml_path):
@@ -94,6 +110,7 @@ class LDSR:
        down_sample_method = 'Lanczos'

        gc.collect()
        if torch.cuda.is_available:
            torch.cuda.empty_cache()

        im_og = image
@@ -131,7 +148,9 @@ class LDSR:

        del model
        gc.collect()
        if torch.cuda.is_available:
            torch.cuda.empty_cache()

        return a


@@ -146,7 +165,7 @@ def get_cond(selected_path):
    c = rearrange(c, '1 c h w -> 1 h w c')
    c = 2. * c - 1.

    c = c.to(torch.device("cuda"))
    c = c.to(shared.device)
    example["LR_image"] = c
    example["image"] = c_up

+1 −0
Original line number Diff line number Diff line
@@ -59,6 +59,7 @@ def on_ui_settings():
    import gradio as gr

    shared.opts.add_option("ldsr_steps", shared.OptionInfo(100, "LDSR processing steps. Lower = faster", gr.Slider, {"minimum": 1, "maximum": 200, "step": 1}, section=('upscaling', "Upscaling")))
    shared.opts.add_option("ldsr_cached", shared.OptionInfo(False, "Cache LDSR model in memory", gr.Checkbox, {"interactive": True}, section=('upscaling', "Upscaling")))


script_callbacks.on_ui_settings(on_ui_settings)
+1 −1
Original line number Diff line number Diff line
@@ -88,7 +88,7 @@ function checkBrackets(evt) {
  if(counterElt.title != '') {
    counterElt.style = 'color: #FF5555;';
  } else {
    counterElt.style = 'color: #000;';
    counterElt.style = '';
  }
}

+11 −6
Original line number Diff line number Diff line
@@ -13,13 +13,15 @@ from skimage import exposure
from typing import Any, Dict, List, Optional

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste, script_callbacks
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
import modules.face_restoration
import modules.images as images
import modules.styles
import modules.sd_models as sd_models
import modules.sd_vae as sd_vae
import logging
from ldm.data.util import AddMiDaS
from ldm.models.diffusion.ddpm import LatentDepth2ImageDiffusion
@@ -454,8 +456,10 @@ def process_images(p: StableDiffusionProcessing) -> Processed:

    try:
        for k, v in p.override_settings.items():
            setattr(opts, k, v)  # we don't call onchange for simplicity which makes changing model impossible
            if k == 'sd_hypernetwork': shared.reload_hypernetworks()  # make onchange call for changing hypernet since it is relatively fast to load on-change, while SD models are not
            setattr(opts, k, v)
            if k == 'sd_hypernetwork': shared.reload_hypernetworks()  # make onchange call for changing hypernet
            if k == 'sd_model_checkpoint': sd_models.reload_model_weights()  # make onchange call for changing SD model
            if k == 'sd_vae': sd_vae.reload_vae_weights()  # make onchange call for changing VAE

        res = process_images_inner(p)

@@ -463,6 +467,8 @@ def process_images(p: StableDiffusionProcessing) -> Processed:
        for k, v in stored_opts.items():
            setattr(opts, k, v)
            if k == 'sd_hypernetwork': shared.reload_hypernetworks()
            if k == 'sd_model_checkpoint': sd_models.reload_model_weights()
            if k == 'sd_vae': sd_vae.reload_vae_weights()

    return res

@@ -571,9 +577,8 @@ def process_images_inner(p: StableDiffusionProcessing) -> Processed:

            devices.torch_gc()

            if opts.filter_nsfw:
                import modules.safety as safety
                x_samples_ddim = modules.safety.censor_batch(x_samples_ddim)
            if p.scripts is not None:
                p.scripts.postprocess_batch(p, x_samples_ddim, batch_number=n)

            for i, x_sample in enumerate(x_samples_ddim):
                x_sample = 255. * np.moveaxis(x_sample.cpu().numpy(), 0, 2)

modules/safety.py

deleted100644 → 0
+0 −42
Original line number Diff line number Diff line
import torch
from diffusers.pipelines.stable_diffusion.safety_checker import StableDiffusionSafetyChecker
from transformers import AutoFeatureExtractor
from PIL import Image

import modules.shared as shared

safety_model_id = "CompVis/stable-diffusion-safety-checker"
safety_feature_extractor = None
safety_checker = None

def numpy_to_pil(images):
    """Convert a numpy image, or a batch of images, into a list of PIL images."""
    # Promote a single HWC image to a one-element batch.
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")
    return [Image.fromarray(frame) for frame in batch]

# check and replace nsfw content
# check and replace nsfw content
def check_safety(x_image):
    """Run the stable-diffusion safety checker over a numpy image batch.

    The feature extractor and checker are loaded lazily on first use and
    kept in module globals. Returns the (possibly censored) images and the
    per-image NSFW flags produced by the checker.
    """
    global safety_feature_extractor, safety_checker

    if safety_feature_extractor is None:
        # One-time download/initialisation of the pretrained checker pair.
        safety_feature_extractor = AutoFeatureExtractor.from_pretrained(safety_model_id)
        safety_checker = StableDiffusionSafetyChecker.from_pretrained(safety_model_id)

    checker_input = safety_feature_extractor(numpy_to_pil(x_image), return_tensors="pt")
    checked_images, nsfw_flags = safety_checker(images=x_image, clip_input=checker_input.pixel_values)

    return checked_images, nsfw_flags


def censor_batch(x):
    """Censor NSFW frames in a tensor batch and return it as a CPU tensor (NCHW)."""
    # Checker works on NHWC numpy arrays, so move off the GPU and transpose.
    as_numpy = x.cpu().permute(0, 2, 3, 1).numpy()
    checked, _nsfw_flags = check_safety(as_numpy)
    # Back to the NCHW tensor layout the pipeline expects.
    return torch.from_numpy(checked).permute(0, 3, 1, 2)
Loading