Unverified Commit fd4461d4 authored by AUTOMATIC1111's avatar AUTOMATIC1111 Committed by GitHub
Browse files

Merge pull request #6196 from philpax/add-embeddings-api

feat(api): add /sdapi/v1/embeddings
parents f39a79d1 c65909ad
Loading
Loading
Loading
Loading
+21 −0
Original line number Diff line number Diff line
@@ -100,6 +100,7 @@ class Api:
        self.add_api_route("/sdapi/v1/prompt-styles", self.get_prompt_styles, methods=["GET"], response_model=List[PromptStyleItem])
        self.add_api_route("/sdapi/v1/artist-categories", self.get_artists_categories, methods=["GET"], response_model=List[str])
        self.add_api_route("/sdapi/v1/artists", self.get_artists, methods=["GET"], response_model=List[ArtistItem])
        self.add_api_route("/sdapi/v1/embeddings", self.get_embeddings, methods=["GET"], response_model=EmbeddingsResponse)
        self.add_api_route("/sdapi/v1/refresh-checkpoints", self.refresh_checkpoints, methods=["POST"])
        self.add_api_route("/sdapi/v1/create/embedding", self.create_embedding, methods=["POST"], response_model=CreateResponse)
        self.add_api_route("/sdapi/v1/create/hypernetwork", self.create_hypernetwork, methods=["POST"], response_model=CreateResponse)
@@ -327,6 +328,26 @@ class Api:
    def get_artists(self):
        return [{"name":x[0], "score":x[1], "category":x[2]} for x in shared.artist_db.artists]

    def get_embeddings(self):
        db = sd_hijack.model_hijack.embedding_db

        def convert_embedding(embedding):
            return {
                "step": embedding.step,
                "sd_checkpoint": embedding.sd_checkpoint,
                "sd_checkpoint_name": embedding.sd_checkpoint_name,
                "shape": embedding.shape,
                "vectors": embedding.vectors,
            }

        def convert_embeddings(embeddings):
            return {embedding.name: convert_embedding(embedding) for embedding in embeddings.values()}

        return {
            "loaded": convert_embeddings(db.word_embeddings),
            "skipped": convert_embeddings(db.skipped_embeddings),
        }

    def refresh_checkpoints(self):
        shared.refresh_checkpoints()

+10 −0
Original line number Diff line number Diff line
@@ -249,3 +249,13 @@ class ArtistItem(BaseModel):
    score: float = Field(title="Score")
    category: str = Field(title="Category")

class EmbeddingItem(BaseModel):
    step: Optional[int] = Field(title="Step", description="The number of steps that were used to train this embedding, if available")
    sd_checkpoint: Optional[str] = Field(title="SD Checkpoint", description="The hash of the checkpoint this embedding was trained on, if available")
    sd_checkpoint_name: Optional[str] = Field(title="SD Checkpoint Name", description="The name of the checkpoint this embedding was trained on, if available. Note that this is the name that was used by the trainer; for a stable identifier, use `sd_checkpoint` instead")
    shape: int = Field(title="Shape", description="The length of each individual vector in the embedding")
    vectors: int = Field(title="Vectors", description="The number of vectors in the embedding")

class EmbeddingsResponse(BaseModel):
    loaded: Dict[str, EmbeddingItem] = Field(title="Loaded", description="Embeddings loaded for the current model")
    skipped: Dict[str, EmbeddingItem] = Field(title="Skipped", description="Embeddings skipped for the current model (likely due to architecture incompatibility)")
 No newline at end of file
+4 −4
Original line number Diff line number Diff line
@@ -59,7 +59,7 @@ class EmbeddingDatabase:
    def __init__(self, embeddings_dir):
        self.ids_lookup = {}
        self.word_embeddings = {}
        self.skipped_embeddings = []
        self.skipped_embeddings = {}
        self.dir_mtime = None
        self.embeddings_dir = embeddings_dir
        self.expected_shape = -1
@@ -91,7 +91,7 @@ class EmbeddingDatabase:
        self.dir_mtime = mt
        self.ids_lookup.clear()
        self.word_embeddings.clear()
        self.skipped_embeddings = []
        self.skipped_embeddings.clear()
        self.expected_shape = self.get_expected_shape()

        def process_file(path, filename):
@@ -136,7 +136,7 @@ class EmbeddingDatabase:
            if self.expected_shape == -1 or self.expected_shape == embedding.shape:
                self.register_embedding(embedding, shared.sd_model)
            else:
                self.skipped_embeddings.append(name)
                self.skipped_embeddings[name] = embedding

        for fn in os.listdir(self.embeddings_dir):
            try:
@@ -153,7 +153,7 @@ class EmbeddingDatabase:

        print(f"Textual inversion embeddings loaded({len(self.word_embeddings)}): {', '.join(self.word_embeddings.keys())}")
        if len(self.skipped_embeddings) > 0:
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings)}")
            print(f"Textual inversion embeddings skipped({len(self.skipped_embeddings)}): {', '.join(self.skipped_embeddings.keys())}")

    def find_embedding_at_position(self, tokens, offset):
        token = tokens[offset]