Unverified Commit 278e7c71 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #1194 from liamkerr/token_count

Token count
parents 1deac2b6 7ca9858c
Loading
Loading
Loading
Loading
+19 −0
Original line number Diff line number Diff line
@@ -182,4 +182,23 @@ onUiUpdate(function(){
    });

    json_elem.parentElement.style.display="none"

	if (!txt2img_textarea) {
		txt2img_textarea = gradioApp().querySelector("#txt2img_prompt > label > textarea");
		txt2img_textarea?.addEventListener("input", () => update_token_counter("txt2img_token_button"));
	}
	if (!img2img_textarea) {
		img2img_textarea = gradioApp().querySelector("#img2img_prompt > label > textarea");
		img2img_textarea?.addEventListener("input", () => update_token_counter("img2img_token_button"));
	}
})

let txt2img_textarea, img2img_textarea = undefined;
let wait_time = 800
let token_timeout;

function update_token_counter(button_id) {
	if (token_timeout)
		clearTimeout(token_timeout);
	token_timeout = setTimeout(() => gradioApp().getElementById(button_id)?.click(), wait_time);
}
+21 −8
Original line number Diff line number Diff line
@@ -180,6 +180,7 @@ class StableDiffusionModelHijack:
    dir_mtime = None
    layers = None
    circular_enabled = False
    clip = None

    def load_textual_inversion_embeddings(self, dirname, model):
        mt = os.path.getmtime(dirname)
@@ -242,6 +243,7 @@ class StableDiffusionModelHijack:

        model_embeddings.token_embedding = EmbeddingsWithFixes(model_embeddings.token_embedding, self)
        m.cond_stage_model = FrozenCLIPEmbedderWithCustomWords(m.cond_stage_model, self)
        self.clip = m.cond_stage_model

        if cmd_opts.opt_split_attention_v1:
            ldm.modules.attention.CrossAttention.forward = split_cross_attention_forward_v1
@@ -268,6 +270,10 @@ class StableDiffusionModelHijack:
        for layer in [layer for layer in self.layers if type(layer) == torch.nn.Conv2d]:
            layer.padding_mode = 'circular' if enable else 'zeros'

    def tokenize(self, text):
        max_length = self.clip.max_length - 2
        _, remade_batch_tokens, _, _, _, token_count = self.clip.process_text([text])
        return remade_batch_tokens[0], token_count, max_length

class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
    def __init__(self, wrapped, hijack):
@@ -294,14 +300,16 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
            if mult != 1.0:
                self.token_mults[ident] = mult

    def forward(self, text):
        self.hijack.fixes = []
        self.hijack.comments = []
        remade_batch_tokens = []
    def process_text(self, text):
        id_start = self.wrapped.tokenizer.bos_token_id
        id_end = self.wrapped.tokenizer.eos_token_id
        maxlen = self.wrapped.max_length
        used_custom_terms = []
        remade_batch_tokens = []
        overflowing_words = []
        hijack_comments = []
        hijack_fixes = []
        token_count = 0

        cache = {}
        batch_tokens = self.wrapped.tokenizer(text, truncation=False, add_special_tokens=False)["input_ids"]
@@ -353,9 +361,8 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
                    ovf = remade_tokens[maxlen - 2:]
                    overflowing_words = [vocab.get(int(x), "") for x in ovf]
                    overflowing_text = self.wrapped.tokenizer.convert_tokens_to_string(''.join(overflowing_words))

                    self.hijack.comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")

                    hijack_comments.append(f"Warning: too many input tokens; some ({len(overflowing_words)}) have been truncated:\n{overflowing_text}\n")
                token_count = len(remade_tokens)
                remade_tokens = remade_tokens + [id_end] * (maxlen - 2 - len(remade_tokens))
                remade_tokens = [id_start] + remade_tokens[0:maxlen-2] + [id_end]
                cache[tuple_tokens] = (remade_tokens, fixes, multipliers)
@@ -364,8 +371,14 @@ class FrozenCLIPEmbedderWithCustomWords(torch.nn.Module):
            multipliers = [1.0] + multipliers[0:maxlen - 2] + [1.0]

            remade_batch_tokens.append(remade_tokens)
            self.hijack.fixes.append(fixes)
            hijack_fixes.append(fixes)
            batch_multipliers.append(multipliers)
        return batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count

    def forward(self, text):
        batch_multipliers, remade_batch_tokens, used_custom_terms, hijack_comments, hijack_fixes, token_count = self.process_text(text)
        self.hijack.fixes = hijack_fixes
        self.hijack.comments = hijack_comments

        if len(used_custom_terms) > 0:
            self.hijack.comments.append("Used custom terms: " + ", ".join([f'{word} [{checksum}]' for word, checksum in used_custom_terms]))
+9 −1
Original line number Diff line number Diff line
@@ -22,6 +22,7 @@ from modules.paths import script_path
from modules.shared import opts, cmd_opts
import modules.shared as shared
from modules.sd_samplers import samplers, samplers_for_img2img
from modules.sd_hijack import model_hijack
import modules.ldsr_model
import modules.scripts
import modules.gfpgan_model
@@ -333,6 +334,10 @@ def connect_reuse_seed(seed: gr.Number, reuse_seed: gr.Button, generation_info:
        outputs=[seed, dummy_component]
    )

def update_token_counter(text):
    tokens, token_count, max_length = model_hijack.tokenize(text)
    style_class = ' class="red"' if (token_count > max_length) else ""
    return f"<span {style_class}>{token_count}/{max_length}</span>"

def create_toprow(is_img2img):
    id_part = "img2img" if is_img2img else "txt2img"
@@ -342,11 +347,14 @@ def create_toprow(is_img2img):
            with gr.Row():
                with gr.Column(scale=80):
                    with gr.Row():
                        prompt = gr.Textbox(label="Prompt", elem_id="prompt", show_label=False, placeholder="Prompt", lines=2)
                        prompt = gr.Textbox(label="Prompt", elem_id=f"{id_part}_prompt", show_label=False, placeholder="Prompt", lines=2)

                with gr.Column(scale=1, elem_id="roll_col"):
                    roll = gr.Button(value=art_symbol, elem_id="roll", visible=len(shared.artist_db.artists) > 0)
                    paste = gr.Button(value=paste_symbol, elem_id="paste")
                    token_counter = gr.HTML(value="<span></span>", elem_id=f"{id_part}_token_counter")
                    hidden_button = gr.Button(visible=False, elem_id=f"{id_part}_token_button")
                    hidden_button.click(fn=update_token_counter, inputs=[prompt], outputs=[token_counter])

                with gr.Column(scale=10, elem_id="style_pos_col"):
                    prompt_style = gr.Dropdown(label="Style 1", elem_id=f"{id_part}_style_index", choices=[k for k, v in shared.prompt_styles.styles.items()], value=next(iter(shared.prompt_styles.styles.keys())), visible=len(shared.prompt_styles.styles) > 1)
+4 −0
Original line number Diff line number Diff line
@@ -389,3 +389,7 @@ input[type="range"]{
  border-radius: 8px;
  display: none;
}

.red {
	color: red;
}