Unverified Commit 04a561c1 authored by Vladimir Mandic's avatar Vladimir Mandic Committed by GitHub
Browse files

add option to skip interrogate categories

parent efa7287b
Loading
Loading
Loading
Loading
+18 −14
Original line number Diff line number Diff line
@@ -2,6 +2,7 @@ import os
import sys
import traceback
from collections import namedtuple
from pathlib import Path
import re

import torch
@@ -20,12 +21,16 @@ Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")

category_types = ["artists", "flavors", "mediums", "movements"]
def category_types():
    return [f.stem for f in Path(shared.interrogator.content_dir).glob('*.txt')]


def download_default_clip_interrogate_categories(content_dir):
    print("Downloading CLIP categories...")

    tmpdir = content_dir + "_tmp"
    category_types = ["artists", "flavors", "mediums", "movements"]

    try:
        os.makedirs(tmpdir)
        for category_type in category_types:
@@ -48,33 +53,32 @@ class InterrogateModels:

    def __init__(self, content_dir):
        self.loaded_categories = None
        self.selected_categories = []
        self.skip_categories = []
        self.content_dir = content_dir
        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")

    def categories(self):
        if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories:
        if not os.path.exists(self.content_dir):
            download_default_clip_interrogate_categories(self.content_dir)

        if self.loaded_categories is not None and self.skip_categories == shared.opts.interrogate_clip_skip_categories:
           return self.loaded_categories

        self.loaded_categories = []

        if not os.path.exists(self.content_dir):
            download_default_clip_interrogate_categories(self.content_dir)

        if os.path.exists(self.content_dir):
            self.selected_categories = shared.opts.interrogate_clip_categories
            for category_type in category_types:
                if 'all' not in self.selected_categories and category_type not in self.selected_categories:
                    continue
                filename = os.path.join(self.content_dir, f"{category_type}.txt")
                if not os.path.isfile(filename):
            self.skip_categories = shared.opts.interrogate_clip_skip_categories
            category_types = []
            for filename in Path(self.content_dir).glob('*.txt'):
                category_types.append(filename.stem)
                if filename.stem in self.skip_categories:
                    continue
                m = re_topn.search(filename)
                m = re_topn.search(filename.stem)
                topn = 1 if m is None else int(m.group(1))
                with open(filename, "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines))
                self.loaded_categories.append(Category(name=filename.stem, topn=topn, items=lines))

        return self.loaded_categories

+1 −1
Original line number Diff line number Diff line
@@ -424,7 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
    "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
    "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
    "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
    "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}),
    "interrogate_clip_skip_categories": OptionInfo([], "CLIP: skip inquire categories", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types()}, refresh=modules.interrogate.category_types),
    "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
    "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
    "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),