Commit df570640 authored by AUTOMATIC's avatar AUTOMATIC
Browse files

do not load aesthetic clip model until it's needed

add refresh button for aesthetic embeddings
add aesthetic params to images' infotext
parent 7d6b388d
Loading
Loading
Loading
Loading
+33 −7
Original line number Diff line number Diff line
@@ -40,6 +40,8 @@ def iter_to_batched(iterable, n=1):


def create_ui():
    import modules.ui

    with gr.Group():
        with gr.Accordion("Open for Clip Aesthetic!", open=False):
            with gr.Row():
@@ -55,6 +57,8 @@ def create_ui():
                                             label="Aesthetic imgs embedding",
                                             value="None")

                modules.ui.create_refresh_button(aesthetic_imgs, shared.update_aesthetic_embeddings, lambda: {"choices": sorted(shared.aesthetic_embeddings.keys())}, "refresh_aesthetic_embeddings")

            with gr.Row():
                aesthetic_imgs_text = gr.Textbox(label='Aesthetic text for imgs',
                                                 placeholder="This text is used to rotate the feature space of the imgs embs",
@@ -66,11 +70,21 @@ def create_ui():
    return aesthetic_weight, aesthetic_steps, aesthetic_lr, aesthetic_slerp, aesthetic_imgs, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative


aesthetic_clip_model = None


def aesthetic_clip():
    global aesthetic_clip_model

    if aesthetic_clip_model is None or aesthetic_clip_model.name_or_path != shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path:
        aesthetic_clip_model = CLIPModel.from_pretrained(shared.sd_model.cond_stage_model.wrapped.transformer.name_or_path)
        aesthetic_clip_model.cpu()

    return aesthetic_clip_model


def generate_imgs_embd(name, folder, batch_size):
    # clipModel = CLIPModel.from_pretrained(
    #     shared.sd_model.cond_stage_model.clipModel.name_or_path
    # )
    model = shared.clip_model.to(device)
    model = aesthetic_clip().to(device)
    processor = CLIPProcessor.from_pretrained(model.name_or_path)

    with torch.no_grad():
@@ -91,7 +105,7 @@ def generate_imgs_embd(name, folder, batch_size):
        path = str(Path(shared.cmd_opts.aesthetic_embeddings_dir) / f"{name}.pt")
        torch.save(embs, path)

        model = model.cpu()
        model.cpu()
        del processor
        del embs
        gc.collect()
@@ -132,7 +146,7 @@ class AestheticCLIP:
        self.image_embs = None
        self.load_image_embs(None)

    def set_aesthetic_params(self, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
    def set_aesthetic_params(self, p, aesthetic_lr=0, aesthetic_weight=0, aesthetic_steps=0, image_embs_name=None,
                             aesthetic_slerp=True, aesthetic_imgs_text="",
                             aesthetic_slerp_angle=0.15,
                             aesthetic_text_negative=False):
@@ -145,6 +159,18 @@ class AestheticCLIP:
        self.aesthetic_steps = aesthetic_steps
        self.load_image_embs(image_embs_name)

        if self.image_embs_name is not None:
            p.extra_generation_params.update({
                "Aesthetic LR": aesthetic_lr,
                "Aesthetic weight": aesthetic_weight,
                "Aesthetic steps": aesthetic_steps,
                "Aesthetic embedding": self.image_embs_name,
                "Aesthetic slerp": aesthetic_slerp,
                "Aesthetic text": aesthetic_imgs_text,
                "Aesthetic text negative": aesthetic_text_negative,
                "Aesthetic slerp angle": aesthetic_slerp_angle,
            })

    def set_skip(self, skip):
        self.skip = skip

@@ -168,7 +194,7 @@ class AestheticCLIP:

            tokens = torch.asarray(remade_batch_tokens).to(device)

            model = copy.deepcopy(shared.clip_model).to(device)
            model = copy.deepcopy(aesthetic_clip()).to(device)
            model.requires_grad_(True)
            if self.aesthetic_imgs_text is not None and len(self.aesthetic_imgs_text) > 0:
                text_embs_2 = model.get_text_features(
+16 −2
Original line number Diff line number Diff line
@@ -4,13 +4,22 @@ import gradio as gr
from modules.shared import script_path
from modules import shared

re_param_code = r"\s*([\w ]+):\s*([^,]+)(?:,|$)"
re_param_code = r'\s*([\w ]+):\s*("(?:\\|\"|[^\"])+"|[^,]*)(?:,|$)'
re_param = re.compile(re_param_code)
re_params = re.compile(r"^(?:" + re_param_code + "){3,}$")
re_imagesize = re.compile(r"^(\d+)x(\d+)$")
type_of_gr_update = type(gr.update())


def quote(text):
    if ',' not in str(text):
        return text

    text = str(text)
    text = text.replace('\\', '\\\\')
    text = text.replace('"', '\\"')
    return f'"{text}"'

def parse_generation_parameters(x: str):
    """parses generation parameters string, the one you see in text field under the picture in UI:
```
@@ -83,7 +92,12 @@ def connect_paste(button, paste_fields, input_comp, js=None):
            else:
                try:
                    valtype = type(output.value)

                    if valtype == bool and v == "False":
                        val = False
                    else:
                        val = valtype(v)

                    res.append(gr.update(value=val))
                except Exception:
                    res.append(gr.update())
+1 −4
Original line number Diff line number Diff line
@@ -109,10 +109,7 @@ def img2img(mode: int, prompt: str, negative_prompt: str, prompt_style: str, pro
        inpainting_mask_invert=inpainting_mask_invert,
    )

    shared.aesthetic_clip.set_aesthetic_params(float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps),
                                               aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text,
                                               aesthetic_slerp_angle,
                                               aesthetic_text_negative)
    shared.aesthetic_clip.set_aesthetic_params(p, float(aesthetic_lr), float(aesthetic_weight), int(aesthetic_steps), aesthetic_imgs, aesthetic_slerp, aesthetic_imgs_text, aesthetic_slerp_angle, aesthetic_text_negative)

    if shared.cmd_opts.enable_console_prompts:
        print(f"\nimg2img: {prompt}", file=shared.progress_print_out)
+2 −2
Original line number Diff line number Diff line
@@ -12,7 +12,7 @@ from skimage import exposure
from typing import Any, Dict, List, Optional

import modules.sd_hijack
from modules import devices, prompt_parser, masking, sd_samplers, lowvram
from modules import devices, prompt_parser, masking, sd_samplers, lowvram, generation_parameters_copypaste
from modules.sd_hijack import model_hijack
from modules.shared import opts, cmd_opts, state
import modules.shared as shared
@@ -318,7 +318,7 @@ def create_infotext(p, all_prompts, all_seeds, all_subseeds, comments, iteration

    generation_params.update(p.extra_generation_params)

    generation_params_text = ", ".join([k if k == v else f'{k}: {v}' for k, v in generation_params.items() if v is not None])
    generation_params_text = ", ".join([k if k == v else f'{k}: {generation_parameters_copypaste.quote(v)}' for k, v in generation_params.items() if v is not None])

    negative_prompt_text = "\nNegative prompt: " + p.negative_prompt if p.negative_prompt else ""

+0 −3
Original line number Diff line number Diff line
@@ -234,9 +234,6 @@ def load_model(checkpoint_info=None):

    sd_hijack.model_hijack.hijack(sd_model)

    if shared.clip_model is None or shared.clip_model.transformer.name_or_path != sd_model.cond_stage_model.wrapped.transformer.name_or_path:
        shared.clip_model = CLIPModel.from_pretrained(sd_model.cond_stage_model.wrapped.transformer.name_or_path)

    sd_model.eval()

    print(f"Model loaded.")
Loading