| from typing import List | |
| from openai import OpenAI | |
| from mcp_server_mariadb_vector.embeddings.base import EmbeddingProvider | |
class OpenAIEmbeddingProvider(EmbeddingProvider):
    """
    OpenAI implementation of the embedding provider.

    Args:
        model: The name of the OpenAI embedding model to use.
        api_key: The API key handed to the OpenAI client.
    """

    # Default output vector length for each known OpenAI embedding model.
    _EMBEDDING_LENGTHS = {
        "text-embedding-3-small": 1536,
        "text-embedding-3-large": 3072,
        "text-embedding-ada-002": 1536,
    }

    # Number of documents sent per embeddings request. Kept deliberately
    # small: assuming OpenAI's documented limits (8192 tokens per input,
    # 300k tokens summed over one request), 32 maximum-size inputs still fit
    # in a single call -- TODO confirm against the current API limits.
    _BATCH_SIZE = 32

    def __init__(self, model: str, api_key: str):
        self.model = model
        self.client = OpenAI(api_key=api_key)

    def length_of_embedding(self) -> int:
        """Get the length of the embedding vectors for the configured model.

        Returns:
            The number of dimensions of a vector produced by ``self.model``.

        Raises:
            ValueError: If ``self.model`` is not a known embedding model.
        """
        try:
            return self._EMBEDDING_LENGTHS[self.model]
        except KeyError:
            # Same exception type and message as before; the KeyError is an
            # implementation detail, so it is kept out of the traceback.
            raise ValueError(f"Unknown embedding model: {self.model}") from None

    def embed_documents(self, documents: List[str]) -> List[List[float]]:
        """Embed a list of documents into vectors.

        Documents are sent in batches of ``_BATCH_SIZE`` rather than one
        request per document, which cuts the number of network round trips.

        Args:
            documents: The texts to embed.

        Returns:
            One embedding per document, in the same order as ``documents``.
            An empty input yields an empty list without calling the API.
        """
        embeddings: List[List[float]] = []
        for start in range(0, len(documents), self._BATCH_SIZE):
            response = self.client.embeddings.create(
                model=self.model,
                input=documents[start : start + self._BATCH_SIZE],
            )
            # Each returned item carries the index of the input it belongs
            # to; sort on it so the output order never depends on the order
            # in which the API happens to list the results.
            ordered = sorted(response.data, key=lambda item: item.index)
            embeddings.extend(item.embedding for item in ordered)
        return embeddings

    def embed_query(self, query: str) -> List[float]:
        """Embed a query into a vector.

        Args:
            query: The text to embed.

        Returns:
            The embedding of ``query``.
        """
        response = self.client.embeddings.create(
            model=self.model,
            input=query,
        )
        # A single input string produces exactly one result entry.
        return response.data[0].embedding