Commit 7d6b388d authored by AUTOMATIC's avatar AUTOMATIC
Browse files

Merge branch 'ae'

parents bf30673f 2362d5f0
Loading
Loading
Loading
Loading
+2 −0
Original line number Diff line number Diff line
@@ -71,6 +71,8 @@ Check the [custom scripts](https://github.com/AUTOMATIC1111/stable-diffusion-web
- DeepDanbooru integration, creates danbooru style tags for anime prompts (add --deepdanbooru to commandline args)
- [xformers](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Xformers), major speed increase for select cards: (add --xformers to commandline args)
- Support for dedicated [inpainting model](https://github.com/runwayml/stable-diffusion#inpainting-with-stable-diffusion) by RunwayML. 
- Aesthetic Gradients, a way to generate images with a specific aesthetic by using CLIP image embeddings (implementation of [https://github.com/vicgalle/stable-diffusion-aesthetic-gradients](https://github.com/vicgalle/stable-diffusion-aesthetic-gradients))


## Installation and Running
Make sure the required [dependencies](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Dependencies) are met and follow the instructions available for both [NVidia](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-NVidia-GPUs) (recommended) and [AMD](https://github.com/AUTOMATIC1111/stable-diffusion-webui/wiki/Install-and-Run-on-AMD-GPUs) GPUs.
+215 −0
Original line number Diff line number Diff line
import copy
import itertools
import os
from pathlib import Path
import html
import gc

import gradio as gr
import torch
from PIL import Image
from torch import optim

from modules import shared
from transformers import CLIPModel, CLIPProcessor, CLIPTokenizer
from tqdm.auto import tqdm, trange
from modules.shared import opts, device


def get_all_images_in_folder(folder):
    """Return the paths of the image files located directly inside *folder*.

    Each returned path is ``folder`` joined with the entry name. Only regular
    files whose names pass ``check_is_valid_image_file`` are kept;
    sub-directories are not descended into.
    """
    image_paths = []
    for entry in os.listdir(folder):
        candidate = os.path.join(folder, entry)
        # Test "is a file" first so directories never reach the name check.
        if os.path.isfile(candidate) and check_is_valid_image_file(entry):
            image_paths.append(candidate)
    return image_paths


def check_is_valid_image_file(filename):
    """Return True when *filename* ends with a supported image suffix.

    The comparison is case-insensitive and looks only at the name; the file
    is never opened.
    """
    supported_suffixes = ('.png', '.jpg', '.jpeg', ".gif", ".tiff", ".webp")
    lowered = filename.lower()
    return any(lowered.endswith(suffix) for suffix in supported_suffixes)


def batched(dataset, total, n=1):
    """Yield consecutive lists of up to *n* items taken from *dataset*.

    Items are fetched by index for ``0 <= i < total`` through
    ``dataset.__getitem__``; the last list is shorter when *total* is not a
    multiple of *n*.
    """
    for start in range(0, total, n):
        stop = min(start + n, total)
        batch = []
        for index in range(start, stop):
            batch.append(dataset.__getitem__(index))
        yield batch


def iter_to_batched(iterable, n=1):
    """Lazily split *iterable* into tuples of at most *n* items.

    Iteration stops at the first empty chunk, so an exhausted (or empty)
    iterable yields nothing further.
    """
    source = iter(iterable)
    # Two-argument iter(): keep calling the lambda until it returns the
    # sentinel, i.e. an empty tuple.
    yield from iter(lambda: tuple(itertools.islice(source, n)), ())


def create_ui():
    """Build the collapsed "Clip Aesthetic" accordion and return its inputs.

    Returns:
        An 8-tuple of gradio components, in this order: aesthetic weight
        slider, aesthetic steps slider, learning-rate textbox, slerp
        checkbox, embedding dropdown, embedding text textbox, slerp angle
        slider and the "is negative text" checkbox.
    """
    with gr.Group(), gr.Accordion("Open for Clip Aesthetic!", open=False):
        # Row 1: how strongly and for how many steps to optimise.
        with gr.Row():
            aesthetic_weight = gr.Slider(
                label="Aesthetic weight",
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.9,
            )
            aesthetic_steps = gr.Slider(
                label="Aesthetic steps",
                minimum=0,
                maximum=50,
                step=1,
                value=5,
            )

        # Row 2: optimiser settings and which stored embedding to use.
        with gr.Row():
            aesthetic_lr = gr.Textbox(
                label='Aesthetic learning rate',
                placeholder="Aesthetic learning rate",
                value="0.0001",
            )
            aesthetic_slerp = gr.Checkbox(label="Slerp interpolation", value=False)
            embedding_names = sorted(shared.aesthetic_embeddings.keys())
            aesthetic_imgs = gr.Dropdown(
                embedding_names,
                label="Aesthetic imgs embedding",
                value="None",
            )

        # Row 3: optional text used to rotate the image embedding.
        with gr.Row():
            aesthetic_imgs_text = gr.Textbox(
                label='Aesthetic text for imgs',
                placeholder="This text is used to rotate the feature space of the imgs embs",
                value="",
            )
            aesthetic_slerp_angle = gr.Slider(
                label='Slerp angle',
                minimum=0,
                maximum=1,
                step=0.01,
                value=0.1,
            )
            aesthetic_text_negative = gr.Checkbox(label="Is negative text", value=False)

    components = (
        aesthetic_weight,
        aesthetic_steps,
        aesthetic_lr,
        aesthetic_slerp,
        aesthetic_imgs,
        aesthetic_imgs_text,
        aesthetic_slerp_angle,
        aesthetic_text_negative,
    )
    return components


def generate_imgs_embd(name, folder, batch_size):
    """Create an aesthetic embedding from the images in *folder* and save it.

    Every image found by ``get_all_images_in_folder`` is run through the
    shared CLIP model in batches of *batch_size*. The per-image features are
    averaged into a single row, which is written to ``<name>.pt`` inside
    ``shared.cmd_opts.aesthetic_embeddings_dir``. If the job is interrupted
    part-way, the embedding is built from the batches processed so far.

    Args:
        name: File stem of the embedding to create; also shown in the
            progress bar and in the result message.
        folder: Directory whose image files are embedded.
        batch_size: Maximum number of images per CLIP forward pass.

    Returns:
        A 4-tuple ``(dropdown_update, dropdown_update, message, "")``. Both
        dropdown updates list the currently known embeddings and reset the
        selection to ``"None"``. When no image was processed, nothing is
        saved and the message says so.
    """
    model = shared.clip_model
    processor = CLIPProcessor.from_pretrained(model.name_or_path)

    embs = []
    try:
        model.to(device)
        with torch.no_grad():
            for paths in tqdm(iter_to_batched(get_all_images_in_folder(folder), batch_size),
                              desc=f"Generating embeddings for {name}"):
                if shared.state.interrupted:
                    break

                # Image.open is lazy and keeps the file open; copy the pixel
                # data and close the handle right away so large folders do
                # not exhaust file descriptors.
                images = []
                for path in paths:
                    with Image.open(path) as img:
                        images.append(img.copy())

                inputs = processor(images=images, return_tensors="pt").to(device)
                embs.append(model.get_image_features(**inputs).cpu())
                del inputs, images
    finally:
        # Always hand the shared model back on the CPU, even if a batch failed.
        model.cpu()
        del processor
        gc.collect()
        torch.cuda.empty_cache()

    def dropdown_update():
        return gr.Dropdown.update(choices=sorted(shared.aesthetic_embeddings.keys()),
                                  label="Imgs embedding",
                                  value="None")

    if not embs:
        # torch.cat([]) would raise; happens for an empty folder or when the
        # job is interrupted before the first batch.
        res = f"""
        No embedding generated for {html.escape(name)}: no images were processed.
        """
        return dropdown_update(), dropdown_update(), res, ""

    embs = torch.cat(embs, dim=0).mean(dim=0, keepdim=True)

    # The generated embedding will be located here
    path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
    torch.save(embs, path)
    del embs

    # name is user supplied; escape it like the path below.
    res = f"""
        Done generating embedding for {html.escape(name)}!
        Aesthetic embedding saved to {html.escape(path)}
        """
    shared.update_aesthetic_embeddings()
    return dropdown_update(), dropdown_update(), res, ""


def slerp(low, high, val):
    """Spherically interpolate between *low* and *high*.

    The angle between the two inputs is measured along ``dim=1`` (after
    normalising along that dimension). The interpolation weights are then
    applied to the original, un-normalised tensors.

    Args:
        low: Tensor returned when ``val == 0``.
        high: Tensor returned when ``val == 1``; same shape as *low*.
        val: Interpolation factor.

    Returns:
        The interpolated tensor, same shape as the inputs. Where the inputs
        are (anti-)parallel the spherical formula divides by zero, so plain
        linear interpolation is used for those entries instead of returning
        NaN.
    """
    low_norm = low / torch.norm(low, dim=1, keepdim=True)
    high_norm = high / torch.norm(high, dim=1, keepdim=True)

    # Rounding can push the cosine slightly outside [-1, 1], where acos is NaN.
    cos_omega = (low_norm * high_norm).sum(1).clamp(-1.0, 1.0)
    omega = torch.acos(cos_omega)
    so = torch.sin(omega)

    # sin(omega) == 0 means the directions coincide (or are opposite).
    degenerate = so.abs() < 1e-6
    ones = torch.ones_like(so)
    safe_so = torch.where(degenerate, ones, so)

    low_weight = torch.where(degenerate, ones * (1.0 - val), torch.sin((1.0 - val) * omega) / safe_so)
    high_weight = torch.where(degenerate, ones * val, torch.sin(val * omega) / safe_so)

    return low_weight.unsqueeze(1) * low + high_weight.unsqueeze(1) * high


class AestheticCLIP:
    """Callable that pulls text conditioning toward a stored image embedding.

    For each call a throw-away copy of ``shared.clip_model`` has its text
    encoder fine-tuned for ``aesthetic_steps`` Adam steps so that the prompt's
    text features become more similar to the selected image embedding. The
    hidden states of that tuned encoder are then blended into the
    conditioning ``z`` with weight ``aesthetic_weight``.

    The object is a no-op (returns ``z`` unchanged) while it is skipped, while
    any of steps / learning rate / weight is zero, or while no embedding is
    selected.
    """

    def __init__(self):
        self.skip = False
        self.aesthetic_steps = 0
        self.aesthetic_weight = 0
        self.aesthetic_lr = 0
        self.slerp = False
        # Boolean flag, same type as the set_aesthetic_params default.
        self.aesthetic_text_negative = False
        self.aesthetic_slerp_angle = 0
        self.aesthetic_imgs_text = ""

        self.image_embs_name = None
        self.image_embs = None
        self.load_image_embs(None)

    def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
                             aesthetic_slerp=True, aesthetic_imgs_text="",
                             aesthetic_slerp_angle=0.15,
                             aesthetic_text_negative=False):
        """Store the settings for the next calls and load the chosen embedding.

        Args:
            aesthetic_lr: Learning rate of the per-call Adam optimiser.
            aesthetic_weight: Blend factor between ``z`` and the tuned output.
            aesthetic_steps: Number of optimisation steps per call.
            image_embs_name: Key into ``shared.aesthetic_embeddings``;
                ``None``, ``""`` or ``"None"`` deselects the embedding.
            aesthetic_slerp: Blend with ``slerp`` instead of a linear mix.
            aesthetic_imgs_text: Optional text whose features are slerped
                into the image embedding before optimising.
            aesthetic_slerp_angle: Slerp factor used for that text.
            aesthetic_text_negative: Use ``image_embs - text_features``
                (normalised) instead of the raw text features.
        """
        self.aesthetic_imgs_text = aesthetic_imgs_text
        self.aesthetic_slerp_angle = aesthetic_slerp_angle
        self.aesthetic_text_negative = aesthetic_text_negative
        self.slerp = aesthetic_slerp
        self.aesthetic_lr = aesthetic_lr
        self.aesthetic_weight = aesthetic_weight
        self.aesthetic_steps = aesthetic_steps
        self.load_image_embs(image_embs_name)

    def set_skip(self, skip):
        """Enable (``False``) or bypass (``True``) the aesthetic step."""
        self.skip = skip

    def load_image_embs(self, image_embs_name):
        """Select the embedding *image_embs_name*, loading it if it changed.

        ``None``, an empty string and ``"None"`` all clear the selection. A
        name equal to the current one is not reloaded. The loaded tensor is
        L2-normalised along its last dimension.
        """
        if image_embs_name is None or len(image_embs_name) == 0 or image_embs_name == "None":
            self.image_embs_name = None
            return
        if self.image_embs_name == image_embs_name:
            return

        # NOTE(security): torch.load unpickles the file; only load embeddings
        # from trusted sources.
        image_embs = torch.load(shared.aesthetic_embeddings[image_embs_name], map_location=device)
        image_embs /= image_embs.norm(dim=-1, keepdim=True)
        image_embs.requires_grad_(False)

        # Commit only after a successful load, so a failed load cannot leave
        # the new name paired with a stale (or missing) tensor.
        self.image_embs = image_embs
        self.image_embs_name = image_embs_name

    def _target_image_embs(self, model, tokenizer):
        """Return the embedding to optimise toward, optionally rotated by text."""
        if not self.aesthetic_imgs_text:
            return self.image_embs

        # The target is a constant of the optimisation; keep it out of autograd.
        with torch.no_grad():
            text_embs_2 = model.get_text_features(
                **tokenizer([self.aesthetic_imgs_text], padding=True, return_tensors="pt").to(device))
            if self.aesthetic_text_negative:
                text_embs_2 = self.image_embs - text_embs_2
                text_embs_2 /= text_embs_2.norm(dim=-1, keepdim=True)
            return slerp(self.image_embs, text_embs_2, self.aesthetic_slerp_angle)

    def _tuned_hidden_states(self, model, tokens, img_embs):
        """Fine-tune *model*'s text encoder on *tokens*; return its hidden states."""
        with torch.enable_grad():
            # We optimize the model to maximize the similarity
            optimizer = optim.Adam(
                model.text_model.parameters(), lr=self.aesthetic_lr
            )

            for _ in trange(self.aesthetic_steps, desc="Aesthetic optimization"):
                text_embs = model.get_text_features(input_ids=tokens)
                text_embs = text_embs / text_embs.norm(dim=-1, keepdim=True)
                sim = text_embs @ img_embs.T
                loss = -sim
                optimizer.zero_grad()
                loss.mean().backward()
                optimizer.step()

        # The final pass is inference only. Running it without grad keeps the
        # returned conditioning from holding on to the autograd graph.
        with torch.no_grad():
            use_earlier_layer = opts.CLIP_stop_at_last_layers > 1
            out = model.text_model(input_ids=tokens, output_hidden_states=use_earlier_layer)
            if use_earlier_layer:
                zn = out.hidden_states[-opts.CLIP_stop_at_last_layers]
                zn = model.text_model.final_layer_norm(zn)
            else:
                zn = out.last_hidden_state
        return zn

    def _blend(self, z, zn):
        """Blend *zn* into the trailing ``zn.shape[1]`` token positions of *z*.

        When *z* is longer than *zn* along dim 1, the leading positions are
        returned untouched and only the trailing chunk is mixed.
        """
        chunk_len = zn.shape[1]
        target = z[:, -chunk_len:]
        if self.slerp:
            mixed = slerp(target, zn, self.aesthetic_weight)
        else:
            mixed = target * (1 - self.aesthetic_weight) + zn * self.aesthetic_weight

        if z.shape[1] > chunk_len:
            mixed = torch.cat((z[:, :-chunk_len], mixed), dim=1)
        return mixed

    def __call__(self, z, remade_batch_tokens):
        """Return *z* with the aesthetic adjustment applied, or *z* itself.

        Args:
            z: Conditioning tensor; the token axis is dim 1.
            remade_batch_tokens: One list of token ids per prompt. Unless
                ``opts.use_old_emphasis_implementation`` is set, the first 75
                ids of each list are wrapped in BOS/EOS before encoding.

        Returns:
            The blended conditioning, or the same *z* object when the
            aesthetic step is inactive.
        """
        inactive = (
            self.skip
            or self.aesthetic_steps == 0
            or self.aesthetic_lr == 0
            or self.aesthetic_weight == 0
            or self.image_embs_name is None
        )
        if inactive:
            return z

        tokenizer = shared.sd_model.cond_stage_model.tokenizer
        if not opts.use_old_emphasis_implementation:
            remade_batch_tokens = [
                [tokenizer.bos_token_id] + x[:75] + [tokenizer.eos_token_id] for x in
                remade_batch_tokens]

        tokens = torch.asarray(remade_batch_tokens).to(device)

        # Work on a private copy so the shared CLIP weights are never modified.
        model = copy.deepcopy(shared.clip_model).to(device)
        try:
            model.requires_grad_(True)
            img_embs = self._target_image_embs(model, tokenizer)
            zn = self._tuned_hidden_states(model, tokens, img_embs)
        finally:
            # Release the copy even when the optimisation fails.
            model.cpu()
            del model
            gc.collect()
            torch.cuda.empty_cache()

        # The tokens cover one chunk, while z may already hold earlier chunks
        # concatenated along dim 1. Slicing zn along dim 0 (the batch axis)
        # to "tile" it broke for anything longer than one chunk, so blend
        # into the matching trailing chunk only.
        return self._blend(z, zn)
+6 −1
Original line number Diff line number Diff line
@@ -56,7 +56,7 @@ def process_batch(p, input_dir, output_dir, args):
                processed_image.save(os.path.join(output_dir, filename))


def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, *args):
def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, prompt_style2: str, init_img, init_img_with_mask, init_img_inpaint, init_mask_inpaint, mask_mode, steps: int, sampler_index: int, mask_blur: int, inpainting_fill: int, restore_faces: bool, tiling: bool, n_iter: int, batch_size: int, cfg_scale: float, denoising_strength: float, seed: int, subseed: int, subseed_strength: float, seed_resize_from_h: int, seed_resize_from_w: int, seed_enable_extras: bool, height: int, width: int, resize_mode: int, inpaint_full_res: bool, inpaint_full_res_padding: int, inpainting_mask_invert: int, img2img_batch_input_dir: str, img2img_batch_output_dir: str, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, aesthetic_imgs=None, aesthetic_slerp=False, aesthetic_imgs_text="", aesthetic_slerp_angle=0.15, aesthetic_text_negative=False, *args):
    is_inpaint = mode == 1
    is_batch = mode == 2

@@ -109,6 +109,11 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
        inpainting_mask_invert=inpainting_mask_invert,
    )

    shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
                                               aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
                                               aesthetic_slerp_angle,
                                               aesthetic_text_negative)

    if shared.cmd_opts.enable_console_prompts:
        print(f"\nimg2img: {prompt}", file=shared.progress_print_out)

+15 −15
Original line number Diff line number Diff line
@@ -19,6 +19,7 @@ attention_CrossAttention_forward = ldm.modules.attention.CrossAttention.forward
diffusionmodules_model_nonlinearity = ldm.modules.diffusionmodules.model.nonlinearity
diffusionmodules_model_AttnBlock_forward = ldm.modules.diffusionmodules.model.AttnBlock.forward


def apply_optimizations():
    undo_optimizations()

@@ -223,7 +224,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):

        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count


    def process_text_old(self, text):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
@@ -333,6 +333,7 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):

            z1 = self.process_tokens(tokens, multipliers)
            z = z1 if z is None else torch.cat((z, z1), axis=-2)
            z = shared.aesthetic_clip(z, remade_batch_tokens)

            remade_batch_tokens = rem_tokens
            batch_multipliers = rem_multipliers
@@ -340,7 +341,6 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):

        return z

    
    def process_tokens(self, remade_batch_tokens, batch_multipliers):
        if not opts.use_old_emphasis_implementation:
            remade_batch_tokens = [[self.wrapped.tokenizer.bos_token_id] + x[:75] + [self.wrapped.tokenizer.eos_token_id] for x in remade_batch_tokens]
+4 −1
Original line number Diff line number Diff line
@@ -21,7 +21,7 @@ checkpoints_loaded = collections.OrderedDict()
try:
    # this silences the annoying "Some weights of the model checkpoint were not used when initializing..." message at start.

    from transformers import logging
    from transformers import logging, CLIPModel

    logging.set_verbosity_error()
except Exception:
@@ -234,6 +234,9 @@ def load_model(checkpoint_info=None):

    sd_hijack.model_hijack.hijack(sd_model)

    if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
        shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)

    sd_model.eval()

    print(f"Model loaded.")
Loading