Kaballas committed
Commit 644bdfe · 1 Parent(s): f1d67a1
Dockerfile ADDED
@@ -0,0 +1,40 @@
+ FROM python:3.10-slim
+
+ # Install system dependencies
+ RUN apt-get update && apt-get install -y --no-install-recommends \
+     gcc \
+     python3-dev \
+     openssl \
+     curl \
+     ca-certificates \
+     gnupg \
+     build-essential && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Set up MariaDB's Python connector dependencies
+ RUN curl -LsSO https://r.mariadb.com/downloads/mariadb_repo_setup && \
+     echo "c4a0f3dade02c51a6a28ca3609a13d7a0f8910cccbb90935a2f218454d3a914a mariadb_repo_setup" | sha256sum -c - && \
+     chmod +x mariadb_repo_setup && \
+     ./mariadb_repo_setup --mariadb-server-version="mariadb-11.7" && \
+     rm mariadb_repo_setup && \
+     apt-get update && \
+     apt-get install -y --no-install-recommends \
+     libmariadb3 \
+     libmariadb-dev && \
+     apt-get clean && \
+     rm -rf /var/lib/apt/lists/*
+
+ # Install uv package manager
+ RUN pip install --no-cache-dir uv
+
+ WORKDIR /app
+
+ # Copy project files
+ COPY . /app
+
+ # Install project dependencies
+ RUN uv sync
+
+ EXPOSE 8000
+
+ CMD ["uv", "run", "mcp-server-mariadb-vector", "--transport", "sse", "--host", "0.0.0.0"]
app.py ADDED
@@ -0,0 +1,7 @@
+ from fastapi import FastAPI
+
+ app = FastAPI()
+
+ @app.get("/")
+ def greet_json():
+     return {"Hello": "World!"}
requirements.txt ADDED
@@ -0,0 +1,2 @@
+ fastapi
+ uvicorn[standard]
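Note: the FastAPI app above is a standalone placeholder and is not referenced by the Dockerfile's CMD; if needed, it can be served directly with the uvicorn dependency from requirements.txt (host and port below are illustrative):

    uvicorn app:app --host 0.0.0.0 --port 8000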
src/mcp_server_mariadb_vector/__init__.py ADDED
File without changes
src/mcp_server_mariadb_vector/app_context.py ADDED
@@ -0,0 +1,33 @@
+ from contextlib import asynccontextmanager
+ from dataclasses import dataclass
+ from typing import AsyncIterator
+
+ import mariadb
+ from mcp.server.fastmcp import FastMCP
+
+ from mcp_server_mariadb_vector.settings import DatabaseSettings
+
+
+ @dataclass
+ class AppContext:
+     conn: mariadb.Connection
+
+
+ @asynccontextmanager
+ async def app_lifespan(server: FastMCP) -> AsyncIterator[AppContext]:
+     """Open a MariaDB connection for the duration of the FastMCP session."""
+
+     cfg = DatabaseSettings()
+     conn = mariadb.connect(
+         host=cfg.host,
+         port=cfg.port,
+         user=cfg.user,
+         password=cfg.password,
+         database=cfg.database,
+     )
+     conn.autocommit = True
+
+     try:
+         yield AppContext(conn=conn)
+     finally:
+         conn.close()
src/mcp_server_mariadb_vector/embeddings/base.py ADDED
@@ -0,0 +1,28 @@
+ from abc import ABC, abstractmethod
+ from enum import Enum
+ from typing import List
+
+
+ class EmbeddingProviderType(Enum):
+     OPENAI = "openai"
+     TEST = "test"
+     # SENTENCE_TRANSFORMERS = "sentence-transformers"
+
+
+ class EmbeddingProvider(ABC):
+     """Abstract base class for embedding providers."""
+
+     @abstractmethod
+     def length_of_embedding(self) -> int:
+         """Get the length of the embedding for a given model."""
+         pass
+
+     @abstractmethod
+     def embed_documents(self, documents: List[str]) -> List[List[float]]:
+         """Embed a list of documents into vectors."""
+         pass
+
+     @abstractmethod
+     def embed_query(self, query: str) -> List[float]:
+         """Embed a query into a vector."""
+         pass
src/mcp_server_mariadb_vector/embeddings/factory.py ADDED
@@ -0,0 +1,22 @@
+ from mcp_server_mariadb_vector.embeddings.base import (
+     EmbeddingProvider,
+     EmbeddingProviderType,
+ )
+ from mcp_server_mariadb_vector.embeddings.openai import OpenAIEmbeddingProvider
+ from mcp_server_mariadb_vector.embeddings.test import TestEmbeddingProvider
+ from mcp_server_mariadb_vector.settings import EmbeddingSettings
+
+
+ def create_embedding_provider(settings: EmbeddingSettings) -> EmbeddingProvider:
+     """
+     Create an instance of the specified embedding provider.
+
+     Args:
+         settings: The settings for the embedding provider.
+     """
+     if settings.provider == EmbeddingProviderType.OPENAI:
+         return OpenAIEmbeddingProvider(settings.model, settings.openai_api_key)
+     elif settings.provider == EmbeddingProviderType.TEST:
+         return TestEmbeddingProvider()
+     else:
+         raise ValueError(f"Unsupported embedding provider: {settings.provider}")
src/mcp_server_mariadb_vector/embeddings/openai.py ADDED
@@ -0,0 +1,48 @@
+ from typing import List
+
+ from openai import OpenAI
+
+ from mcp_server_mariadb_vector.embeddings.base import EmbeddingProvider
+
+
+ class OpenAIEmbeddingProvider(EmbeddingProvider):
+     """
+     OpenAI implementation of the embedding provider.
+
+     Args:
+         model: The name of the OpenAI model to use.
+     """
+
+     def __init__(self, model: str, api_key: str):
+         self.model = model
+         self.client = OpenAI(api_key=api_key)
+
+     def length_of_embedding(self) -> int:
+         """Get the length of the embedding for a given model."""
+         if self.model == "text-embedding-3-small":
+             return 1536
+         elif self.model == "text-embedding-3-large":
+             return 3072
+         else:
+             raise ValueError(f"Unknown embedding model: {self.model}")
+
+     def embed_documents(self, documents: List[str]) -> List[List[float]]:
+         """Embed a list of documents into vectors."""
+         embeddings = [
+             self.client.embeddings.create(
+                 model=self.model,
+                 input=document,
+             )
+             .data[0]
+             .embedding
+             for document in documents
+         ]
+         return embeddings
+
+     def embed_query(self, query: str) -> List[float]:
+         """Embed a query into a vector."""
+         embedding = self.client.embeddings.create(
+             model=self.model,
+             input=query,
+         )
+         return embedding.data[0].embedding
src/mcp_server_mariadb_vector/embeddings/test.py ADDED
@@ -0,0 +1,18 @@
+ from typing import List
+
+ from mcp_server_mariadb_vector.embeddings.base import EmbeddingProvider
+
+
+ class TestEmbeddingProvider(EmbeddingProvider):
+     """
+     Embedding provider for testing.
+     """
+
+     def length_of_embedding(self) -> int:
+         return 3
+
+     def embed_documents(self, documents: List[str]) -> List[List[float]]:
+         return [[0.1, 0.2, 0.3]] * len(documents)
+
+     def embed_query(self, query: str) -> List[float]:
+         return [0.1, 0.2, 0.3]
src/mcp_server_mariadb_vector/server.py ADDED
@@ -0,0 +1,257 @@
+ import argparse
+ import json
+ from typing import Annotated, List, Literal
+
+ import mariadb
+ from fastmcp import Context, FastMCP
+ from pydantic import Field
+
+ from mcp_server_mariadb_vector.app_context import app_lifespan
+ from mcp_server_mariadb_vector.embeddings.factory import create_embedding_provider
+ from mcp_server_mariadb_vector.settings import EmbeddingSettings
+
+ mcp = FastMCP(
+     "Mariadb Vector",
+     lifespan=app_lifespan,
+     dependencies=["mariadb", "openai", "pydantic", "pydantic-settings"],
+ )
+
+
+ embedding_provider = create_embedding_provider(EmbeddingSettings())
+
+
+ @mcp.tool()
+ def mariadb_create_vector_store(
+     ctx: Context,
+     vector_store_name: Annotated[
+         str,
+         Field(description="The name of the vector store to create"),
+     ],
+     distance_function: Annotated[
+         Literal["euclidean", "cosine"],
+         Field(description="The distance function to use."),
+     ] = "euclidean",
+ ) -> str:
+     """Create a vector store in the MariaDB database."""
+
+     embedding_length = embedding_provider.length_of_embedding()
+
+     schema_query = f"""
+         CREATE TABLE `{vector_store_name}` (
+             id BIGINT UNSIGNED PRIMARY KEY AUTO_INCREMENT,
+             document LONGTEXT NOT NULL,
+             embedding VECTOR({embedding_length}) NOT NULL,
+             metadata JSON NOT NULL,
+             VECTOR INDEX (embedding) DISTANCE={distance_function}
+         )
+     """
+
+     try:
+         conn = ctx.request_context.lifespan_context.conn
+         with conn.cursor() as cursor:
+             cursor.execute(schema_query)
+     except mariadb.Error as e:
+         return f"Error creating vector store `{vector_store_name}`: {e}"
+
+     return f"Vector store `{vector_store_name}` created successfully."
+
+
+ def is_vector_store(conn, table: str, embedding_length: int) -> bool:
+     """
+     True if `table` has the right schema, with vectors of the correct length, and a VECTOR index.
+     """
+
+     with conn.cursor(dictionary=True) as cur:
+         # check columns
+         cur.execute(f"SHOW COLUMNS FROM `{table}`")
+         rows = {r["Field"]: r for r in cur}
+
+         if set(rows) != {"id", "document", "embedding", "metadata"}:
+             return False
+
+         # id
+         id_type = rows["id"]["Type"].lower()
+         if id_type != "bigint(20) unsigned":
+             return False
+         if (
+             rows["id"]["Null"] != "NO"
+             or rows["id"]["Key"] != "PRI"
+             or "auto_increment" not in rows["id"]["Extra"].lower()
+         ):
+             return False
+
+         # document
+         if (
+             rows["document"]["Type"].lower() != "longtext"
+             or rows["document"]["Null"] != "NO"
+         ):
+             return False
+
+         # embedding
+         if (
+             rows["embedding"]["Type"].lower() != f"vector({embedding_length})"
+             or rows["embedding"]["Null"] != "NO"
+         ):
+             return False
+
+         # metadata
+         if (
+             rows["metadata"]["Type"].lower() != "longtext"
+             or rows["metadata"]["Null"] != "NO"
+         ):
+             return False
+
+         # check vector index
+         cur.execute(f"""
+             SHOW INDEX FROM `{table}`
+             WHERE Index_type = 'VECTOR' AND Column_name = 'embedding'
+         """)
+         if cur.fetchone() is None:
+             return False
+
+     return True
+
+
+ @mcp.tool()
+ def mariadb_list_vector_stores(ctx: Context) -> str:
+     """List all vector stores in a MariaDB database."""
+     try:
+         conn = ctx.request_context.lifespan_context.conn
+         with conn.cursor() as cursor:
+             cursor.execute("SHOW TABLES")
+             tables = [table[0] for table in cursor]
+     except mariadb.Error as e:
+         return f"Error listing vector stores: {e}"
+
+     embedding_length = embedding_provider.length_of_embedding()
+     vector_stores = [
+         table for table in tables if is_vector_store(conn, table, embedding_length)
+     ]
+
+     return "Vector stores: " + ", ".join(vector_stores)
+
+
+ @mcp.tool()
+ def mariadb_delete_vector_store(
+     ctx: Context,
+     vector_store_name: Annotated[
+         str, Field(description="The name of the vector store to delete.")
+     ],
+ ) -> str:
+     """Delete a vector store in the MariaDB database."""
+
+     try:
+         conn = ctx.request_context.lifespan_context.conn
+         with conn.cursor() as cursor:
+             cursor.execute(f"DROP TABLE `{vector_store_name}`")
+     except mariadb.Error as e:
+         return f"Error deleting vector store `{vector_store_name}`: {e}"
+
+     return f"Vector store `{vector_store_name}` deleted successfully."
+
+
+ @mcp.tool()
+ def mariadb_insert_documents(
+     ctx: Context,
+     vector_store_name: Annotated[
+         str, Field(description="The name of the vector store to insert documents into.")
+     ],
+     documents: Annotated[
+         List[str], Field(description="The documents to insert into the vector store.")
+     ],
+     metadata: Annotated[
+         List[dict], Field(description="The metadata of the documents to insert.")
+     ],
+ ) -> str:
+     """Insert documents into a vector store."""
+
+     embeddings = embedding_provider.embed_documents(documents)
+
+     metadata_json = [json.dumps(item) for item in metadata]
+
+     insert_query = f"""
+         INSERT INTO `{vector_store_name}` (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)
+     """
+     try:
+         conn = ctx.request_context.lifespan_context.conn
+         with conn.cursor() as cursor:
+             cursor.executemany(
+                 insert_query, [(doc, str(emb), meta) for doc, emb, meta in zip(documents, embeddings, metadata_json)]
+             )
+     except mariadb.Error as e:
+         return f"Error inserting documents into `{vector_store_name}`: {e}"
+
+     return f"Documents inserted into `{vector_store_name}` successfully."
+
+
+ @mcp.tool()
+ def mariadb_search_vector_store(
+     ctx: Context,
+     query: Annotated[str, Field(description="The query to search for.")],
+     vector_store_name: Annotated[
+         str, Field(description="The name of the vector store to search.")
+     ],
+     k: Annotated[int, Field(gt=0, description="The number of results to return.")] = 5,
+ ) -> str:
+     """Search a vector store for the most similar documents to a query."""
+
+     embedding = embedding_provider.embed_query(query)
+
+     search_query = f"""
+         SELECT
+             document,
+             metadata,
+             VEC_DISTANCE_EUCLIDEAN(embedding, VEC_FromText(%s)) AS distance
+         FROM `{vector_store_name}`
+         ORDER BY distance ASC
+         LIMIT %s
+     """
+
+     try:
+         conn = ctx.request_context.lifespan_context.conn
+         with conn.cursor(buffered=True) as cursor:
+             cursor.execute(
+                 search_query,
+                 (str(embedding), k),
+             )
+             rows = cursor.fetchall()
+     except mariadb.Error as e:
+         return f"Error searching vector store `{vector_store_name}`: {e}"
+
+     if not rows:
+         return "No similar context found."
+
+     return "\n\n".join(
+         f"Document: {row[0]}\nMetadata: {json.loads(row[1])}\nDistance: {row[2]}"
+         for row in rows
+     )
+
+
+ def main():
+     parser = argparse.ArgumentParser()
+     parser.add_argument(
+         "--transport",
+         choices=["stdio", "sse"],
+         default="stdio",
+     )
+     parser.add_argument(
+         "--host",
+         type=str,
+         default="127.0.0.1",
+     )
+     parser.add_argument(
+         "--port",
+         type=int,
+         default=8000,
+     )
+
+     args = parser.parse_args()
+
+     if args.transport == "sse":
+         mcp.run(transport=args.transport, host=args.host, port=args.port)
+     else:
+         mcp.run(transport=args.transport)
+
+
+ if __name__ == "__main__":
+     main()
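Note: the entry point above parses --transport, --host, and --port; a local SSE run matching the Dockerfile's CMD might look like this (the mcp-server-mariadb-vector script name is assumed to be registered in the project's pyproject.toml, which is not shown in this commit):

    uv run mcp-server-mariadb-vector --transport sse --host 0.0.0.0 --port 8000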
src/mcp_server_mariadb_vector/settings.py ADDED
@@ -0,0 +1,22 @@
+ from typing import Optional
+
+ from pydantic import Field
+ from pydantic_settings import BaseSettings
+
+ from mcp_server_mariadb_vector.embeddings.base import EmbeddingProviderType
+
+
+ class DatabaseSettings(BaseSettings):
+     host: str = Field(default="127.0.0.1", alias="MARIADB_HOST")
+     port: int = Field(default=3306, alias="MARIADB_PORT")
+     user: str = Field(..., alias="MARIADB_USER")
+     password: str = Field(..., alias="MARIADB_PASSWORD")
+     database: str = Field(..., alias="MARIADB_DATABASE")
+
+
+ class EmbeddingSettings(BaseSettings):
+     provider: EmbeddingProviderType = Field(
+         default=EmbeddingProviderType.OPENAI, alias="EMBEDDING_PROVIDER"
+     )
+     model: str = Field(default="text-embedding-3-small", alias="EMBEDDING_MODEL")
+     openai_api_key: Optional[str] = Field(default=None, alias="OPENAI_API_KEY")
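Note: both settings classes read their values from environment variables via the aliases above; an illustrative configuration (all values are placeholders) would be:

    MARIADB_HOST=127.0.0.1
    MARIADB_PORT=3306
    MARIADB_USER=app
    MARIADB_PASSWORD=change-me
    MARIADB_DATABASE=vectors
    EMBEDDING_PROVIDER=openai
    EMBEDDING_MODEL=text-embedding-3-small
    OPENAI_API_KEY=sk-...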