File size: 7,509 Bytes
644bdfe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import argparse
import json
from typing import Annotated, List, Literal

import mariadb
from fastmcp import Context, FastMCP
from pydantic import Field

from mcp_server_mariadb_vector.app_context import app_lifespan
from mcp_server_mariadb_vector.embeddings.factory import create_embedding_provider
from mcp_server_mariadb_vector.settings import EmbeddingSettings

# Server instance. `app_lifespan` provides the lifespan context from which the
# tools below read `ctx.request_context.lifespan_context.conn`.
mcp = FastMCP(
    "Mariadb Vector",
    dependencies=["mariadb", "openai", "pydantic", "pydantic-settings"],
    lifespan=app_lifespan,
)


# Single embedding provider shared by every tool, built from EmbeddingSettings.
embedding_provider = create_embedding_provider(EmbeddingSettings())


@mcp.tool()
def mariadb_create_vector_store(
    ctx: Context,
    vector_store_name: Annotated[
        str,
        Field(description="The name of the vector store to create"),
    ],
    distance_function: Annotated[
        Literal["euclidean", "cosine"],
        Field(description="The distance function to use."),
    ] = "euclidean",
) -> str:
    """Create a vector store in the MariaDB database."""
    # NOTE: FastMCP publishes the docstring above as the tool description, so
    # implementation notes live in comments rather than in the docstring.
    #
    # Creates a table with columns id / document / embedding / metadata and a
    # VECTOR index on `embedding`, sized to the provider's embedding length.
    # Returns a success message, or an error message if MariaDB rejects the DDL.

    embedding_length = embedding_provider.length_of_embedding()

    # Identifiers cannot be bound as query parameters, so the name is
    # interpolated inside backticks. Doubling any embedded backtick keeps a
    # crafted name from closing the quoted identifier and injecting SQL.
    quoted_name = vector_store_name.replace("`", "``")

    # `distance_function` is constrained by its Literal annotation when called
    # as an MCP tool. MariaDB's JSON type is an alias for LONGTEXT, which is
    # what `is_vector_store` expects to see for the metadata column.
    schema_query = f"""
    CREATE TABLE `{quoted_name}` (
        id BIGINT UNSIGNED PRIMARY KEY AUTO_INCREMENT,
        document LONGTEXT NOT NULL,
        embedding VECTOR({embedding_length}) NOT NULL,
        metadata JSON NOT NULL,
        VECTOR INDEX (embedding) DISTANCE={distance_function}
    )
    """

    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute(schema_query)
    except mariadb.Error as e:
        return f"Error creating vector store `{vector_store_name}`: {e}"

    return f"Vector store `{vector_store_name}` created successfully."


def is_vector_store(conn, table: str, embedding_length: int) -> bool:
    """
    True if `table` has the right schema, with vectors of the correct length, and a VECTOR index.

    The expected layout is the one ``mariadb_create_vector_store`` creates:
    exactly the columns ``id``, ``document``, ``embedding`` and ``metadata``,
    all NOT NULL, with a VECTOR index on ``embedding``.

    Args:
        conn: Open database connection whose ``cursor(dictionary=True)``
            yields rows as dicts keyed by column name.
        table: Name of the table to inspect. Backticks are escaped, so any
            name returned by ``SHOW TABLES`` can be passed safely.
        embedding_length: Required dimension of the ``embedding`` column.

    Returns:
        True when every check passes, False otherwise.

    Raises:
        Any database error raised by the cursor (e.g. the table cannot be
        described) is propagated to the caller.
    """
    # Identifiers cannot be parameterized; doubling embedded backticks keeps
    # the whole name inside the quoted identifier.
    quoted = table.replace("`", "``")

    with conn.cursor(dictionary=True) as cur:
        # check columns
        cur.execute(f"SHOW COLUMNS FROM `{quoted}`")
        rows = {r["Field"]: r for r in cur}

        if set(rows) != {"id", "document", "embedding", "metadata"}:
            return False

        # id: unsigned BIGINT auto-increment primary key. MariaDB reports the
        # display width ("bigint(20) unsigned"); the width-less spelling is
        # accepted too in case a server omits it.
        id_col = rows["id"]
        if id_col["Type"].lower() not in ("bigint(20) unsigned", "bigint unsigned"):
            return False
        if (
            id_col["Null"] != "NO"
            or id_col["Key"] != "PRI"
            or "auto_increment" not in id_col["Extra"].lower()
        ):
            return False

        # Remaining columns: exact type and NOT NULL. `metadata` is declared as
        # JSON, which MariaDB stores (and reports) as LONGTEXT.
        expected_types = {
            "document": "longtext",
            "embedding": f"vector({embedding_length})",
            "metadata": "longtext",
        }
        for column, expected in expected_types.items():
            col = rows[column]
            if col["Type"].lower() != expected or col["Null"] != "NO":
                return False

        # check vector index
        cur.execute(f"""
            SHOW INDEX FROM `{quoted}`
            WHERE Index_type = 'VECTOR' AND Column_name = 'embedding'
        """)
        if cur.fetchone() is None:
            return False

    return True


@mcp.tool()
def mariadb_list_vector_stores(ctx: Context) -> str:
    """List all vector stores in a MariaDB database."""
    # NOTE: the docstring above is the published tool description.
    #
    # Lists every table and keeps those whose schema matches a vector store
    # for the current embedding length. Every database round-trip, including
    # the SHOW COLUMNS / SHOW INDEX queries issued by `is_vector_store`, is
    # inside the `try`. A failing table therefore yields the tool's error
    # message instead of an uncaught exception.
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute("SHOW TABLES")
            tables = [table[0] for table in cursor]

        embedding_length = embedding_provider.length_of_embedding()
        vector_stores = [
            table for table in tables if is_vector_store(conn, table, embedding_length)
        ]
    except mariadb.Error as e:
        return f"Error listing vector stores: {e}"

    return "Vector stores: " + ", ".join(vector_stores)


@mcp.tool()
def mariadb_delete_vector_store(
    ctx: Context,
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to delete.")
    ],
) -> str:
    """Delete a vector store in the MariaDB database."""
    # NOTE: the docstring above is the published tool description.
    #
    # Drops the named table and returns a success or error message.
    # NOTE(review): no schema check is made, so any table with this name is
    # dropped, whether or not it is a vector store.

    # Identifiers cannot be bound as parameters; doubling embedded backticks
    # stops a crafted name from closing the quotes and appending SQL.
    quoted_name = vector_store_name.replace("`", "``")

    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.execute(f"DROP TABLE `{quoted_name}`")
    except mariadb.Error as e:
        return f"Error deleting vector store `{vector_store_name}`: {e}"

    return f"Vector store `{vector_store_name}` deleted successfully."


@mcp.tool()
def mariadb_insert_documents(
    ctx: Context,
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to insert documents into.")
    ],
    documents: Annotated[
        List[str], Field(description="The documents to insert into the vector store.")
    ],
    metadata: Annotated[
        List[dict], Field(description="The metadata of the documents to insert.")
    ],
) -> str:
    """Insert documents into a vector store."""
    # NOTE: the docstring above is the published tool description.
    #
    # documents[i] is stored together with metadata[i] and its embedding.
    # Returns a success message, or an error message for mismatched inputs or
    # a database failure.

    # zip() would silently drop the unmatched tail, so reject mismatched
    # lengths. This runs before any embeddings are requested.
    if len(documents) != len(metadata):
        return (
            f"Error inserting documents into `{vector_store_name}`: got "
            f"{len(documents)} documents but {len(metadata)} metadata entries."
        )
    if not documents:
        return f"No documents to insert into `{vector_store_name}`."

    embeddings = embedding_provider.embed_documents(documents)

    # VEC_FromText parses the bracketed text form of a vector. str() matches
    # how mariadb_search_vector_store serializes its query embedding. It is a
    # no-op if the provider already returns strings.
    # NOTE(review): assumes embeddings are lists of floats — confirm provider.
    rows = [
        (document, str(embedding), json.dumps(meta))
        for document, embedding, meta in zip(documents, embeddings, metadata)
    ]

    # Identifiers cannot be parameterized; double embedded backticks.
    quoted_name = vector_store_name.replace("`", "``")
    insert_query = f"""
    INSERT INTO `{quoted_name}` (document, embedding, metadata) VALUES (%s, VEC_FromText(%s), %s)
    """
    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor() as cursor:
            cursor.executemany(insert_query, rows)
        # Commit explicitly so the rows persist even if the connection runs
        # with autocommit off; harmless when autocommit is on.
        # TODO confirm the autocommit setting in app_lifespan.
        conn.commit()
    except mariadb.Error as e:
        return f"Error inserting documents into `{vector_store_name}`: {e}"

    return f"Documents inserted into `{vector_store_name}` successfully."


@mcp.tool()
def mariadb_search_vector_store(
    ctx: Context,
    query: Annotated[str, Field(description="The query to search for.")],
    vector_store_name: Annotated[
        str, Field(description="The name of the vector store to search.")
    ],
    k: Annotated[int, Field(gt=0, description="The number of results to return.")] = 5,
) -> str:
    """Search a vector store for the most similar documents to a query."""
    # NOTE: the docstring above is the published tool description.
    #
    # Embeds `query`, then returns the `k` nearest rows formatted as
    # Document / Metadata / Distance blocks. It returns a fixed message when
    # nothing is found and an error message on database failure.

    embedding = embedding_provider.embed_query(query)

    # Identifiers cannot be parameterized; double embedded backticks.
    quoted_name = vector_store_name.replace("`", "``")

    # VEC_DISTANCE uses the distance function of the table's VECTOR index
    # (euclidean or cosine). Stores created with DISTANCE=cosine are therefore
    # ranked by cosine and can use their index. A hard-coded
    # VEC_DISTANCE_EUCLIDEAN mis-ranks such stores.
    # NOTE(review): VEC_DISTANCE needs a recent MariaDB release — confirm the
    # deployment's server version.
    search_query = f"""
    SELECT
        document,
        metadata,
        VEC_DISTANCE(embedding, VEC_FromText(%s)) AS distance
    FROM `{quoted_name}`
    ORDER BY distance ASC
    LIMIT %s
    """

    try:
        conn = ctx.request_context.lifespan_context.conn
        with conn.cursor(buffered=True) as cursor:
            # str(embedding) yields the bracketed text VEC_FromText parses;
            # assumes embed_query returns a list of floats — confirm provider.
            cursor.execute(search_query, (str(embedding), k))
            rows = cursor.fetchall()
    except mariadb.Error as e:
        return f"Error searching vector store `{vector_store_name}`: {e}"

    if not rows:
        return "No similar context found."

    return "\n\n".join(
        f"Document: {document}\nMetadata: {json.loads(metadata)}\nDistance: {distance}"
        for document, metadata, distance in rows
    )


def main():
    """Parse command-line options and start the MCP server."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--transport", choices=["stdio", "sse"], default="stdio")
    parser.add_argument("--host", type=str, default="127.0.0.1")
    parser.add_argument("--port", type=int, default=8000)
    args = parser.parse_args()

    # Only the SSE transport is given --host/--port; stdio receives neither.
    run_options = {"transport": args.transport}
    if args.transport == "sse":
        run_options["host"] = args.host
        run_options["port"] = args.port
    mcp.run(**run_options)


if __name__ == "__main__":
    main()