Unverified Commit 925dd09c authored by Vladimir Mandic's avatar Vladimir Mandic Committed by GitHub
Browse files

improve interrogate

parent 59146621
Loading
Loading
Loading
Loading
+17 −12
Original line number Diff line number Diff line
@@ -20,6 +20,7 @@ Category = namedtuple("Category", ["name", "topn", "items"])

re_topn = re.compile(r"\.top(\d+)\.")

category_types = ["artists", "flavors", "mediums", "movements"]

def download_default_clip_interrogate_categories(content_dir):
    print("Downloading CLIP categories...")
@@ -27,12 +28,8 @@ def download_default_clip_interrogate_categories(content_dir):
    tmpdir = content_dir + "_tmp"
    try:
        os.makedirs(tmpdir)

        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/artists.txt", os.path.join(tmpdir, "artists.txt"))
        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/flavors.txt", os.path.join(tmpdir, "flavors.top3.txt"))
        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/mediums.txt", os.path.join(tmpdir, "mediums.txt"))
        torch.hub.download_url_to_file("https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/movements.txt", os.path.join(tmpdir, "movements.txt"))

        for category_type in category_types:
            torch.hub.download_url_to_file(f"https://raw.githubusercontent.com/pharmapsychotic/clip-interrogator/main/clip_interrogator/data/{category_type}.txt", os.path.join(tmpdir, f"{category_type}.txt"))
        os.rename(tmpdir, content_dir)

    except Exception as e:
@@ -51,11 +48,12 @@ class InterrogateModels:

    def __init__(self, content_dir):
        self.loaded_categories = None
        self.selected_categories = []
        self.content_dir = content_dir
        self.running_on_cpu = devices.device_interrogate == torch.device("cpu")

    def categories(self):
        if self.loaded_categories is not None:
        if self.loaded_categories is not None and self.selected_categories == shared.opts.interrogate_clip_categories:
           return self.loaded_categories

        self.loaded_categories = []
@@ -64,14 +62,19 @@ class InterrogateModels:
            download_default_clip_interrogate_categories(self.content_dir)

        if os.path.exists(self.content_dir):
            for filename in os.listdir(self.content_dir):
            self.selected_categories = shared.opts.interrogate_clip_categories
            for category_type in category_types:
                if 'all' not in self.selected_categories and category_type not in self.selected_categories:
                    continue
                filename = os.path.join(self.content_dir, f"{category_type}.txt")
                if not os.path.isfile(filename):
                    continue
                m = re_topn.search(filename)
                topn = 1 if m is None else int(m.group(1))

                with open(os.path.join(self.content_dir, filename), "r", encoding="utf8") as file:
                with open(filename, "r", encoding="utf8") as file:
                    lines = [x.strip() for x in file.readlines()]

                self.loaded_categories.append(Category(name=filename, topn=topn, items=lines))
                self.loaded_categories.append(Category(name=category_type, topn=topn, items=lines))

        return self.loaded_categories

@@ -139,6 +142,8 @@ class InterrogateModels:
    def rank(self, image_features, text_array, top_count=1):
        import clip

        devices.torch_gc()

        if shared.opts.interrogate_clip_dict_limit != 0:
            text_array = text_array[0:int(shared.opts.interrogate_clip_dict_limit)]

+1 −0
Original line number Diff line number Diff line
@@ -424,6 +424,7 @@ options_templates.update(options_section(('interrogate', "Interrogate Options"),
    "interrogate_clip_min_length": OptionInfo(24, "Interrogate: minimum description length (excluding artists, etc..)", gr.Slider, {"minimum": 1, "maximum": 128, "step": 1}),
    "interrogate_clip_max_length": OptionInfo(48, "Interrogate: maximum description length", gr.Slider, {"minimum": 1, "maximum": 256, "step": 1}),
    "interrogate_clip_dict_limit": OptionInfo(1500, "CLIP: maximum number of lines in text file (0 = No limit)"),
    "interrogate_clip_categories": OptionInfo(modules.interrogate.category_types, "CLIP: select which categories to inquire", gr.CheckboxGroup, lambda: {"choices": modules.interrogate.category_types}),
    "interrogate_deepbooru_score_threshold": OptionInfo(0.5, "Interrogate: deepbooru score threshold", gr.Slider, {"minimum": 0, "maximum": 1, "step": 0.01}),
    "deepbooru_sort_alpha": OptionInfo(True, "Interrogate: deepbooru sort alphabetically"),
    "deepbooru_use_spaces": OptionInfo(False, "use spaces for tags in deepbooru"),