Commit 888b928f authored by AUTOMATIC1111's avatar AUTOMATIC1111
Browse files

Merge pull request #14276 from AUTOMATIC1111/fix-styles

Fix styles
parent b55f09c4
Loading
Loading
Loading
Loading
+7 −24
Original line number Diff line number Diff line
@@ -98,10 +98,8 @@ class StyleDatabase:
        self.path = path

        folder, file = os.path.split(self.path)
        self.default_file = file.split("*")[0] + ".csv"
        if self.default_file == ".csv":
            self.default_file = "styles.csv"
        self.default_path = os.path.join(folder, self.default_file)
        filename, _, ext = file.partition('*')
        self.default_path = os.path.join(folder, filename + ext)

        self.prompt_fields = [field for field in PromptStyle._fields if field != "path"]

@@ -155,10 +153,8 @@ class StyleDatabase:
                    row["name"], prompt, negative_prompt, path
                )

    def get_style_paths(self) -> list():
        """
        Returns a list of all distinct paths, including the default path, of
        files that styles are loaded from."""
    def get_style_paths(self) -> set:
        """Returns a set of all distinct paths of files that styles are loaded from."""
        # Update any styles without a path to the default path
        for style in list(self.styles.values()):
            if not style.path:
@@ -172,9 +168,9 @@ class StyleDatabase:
                style_paths.add(style.path)

        # Remove any paths for styles that are just list dividers
        style_paths.remove("do_not_save")
        style_paths.discard("do_not_save")

        return list(style_paths)
        return style_paths

    def get_style_prompts(self, styles):
        return [self.styles.get(x, self.no_style).prompt for x in styles]
@@ -196,20 +192,7 @@ class StyleDatabase:
        # The path argument is deprecated, but kept for backwards compatibility
        _ = path

        # Update any styles without a path to the default path
        for style in list(self.styles.values()):
            if not style.path:
                self.styles[style.name] = style._replace(path=self.default_path)

        # Create a list of all distinct paths, including the default path
        style_paths = set()
        style_paths.add(self.default_path)
        for _, style in self.styles.items():
            if style.path:
                style_paths.add(style.path)

        # Remove any paths for styles that are just list dividers
        style_paths.remove("do_not_save")
        style_paths = self.get_style_paths()

        csv_names = [os.path.split(path)[1].lower() for path in style_paths]