Jacqkues committed
Commit d25ee4b · verified · 1 Parent(s): c555a85

Upload 11 files

agent.py ADDED
@@ -0,0 +1,62 @@
+ from services.ai import rank_tables, generate_sql, generate_answer, correct_sql, evaluate_difficulty
+ from services.utils import filter_tables
+ from openai import OpenAI
+ from database import Database
+ from filesource import FileSource
+ import os
+
+ MAX_TABLE = 3
+
+ client = OpenAI(
+     base_url=os.getenv("LLM_ENDPOINT"),
+     api_key=os.getenv("LLM_KEY")
+ )
+
+ def run_agent(database, prompt):
+     retry = 5
+     tables = database.get_tables_array()
+     use_thinking = False
+
+     # When the schema is large, rerank the tables and keep only the most relevant ones.
+     if len(tables) > MAX_TABLE:
+         print(f"using reranking because the number of tables is greater than {MAX_TABLE}")
+         ranked = rank_tables(prompt, tables)
+         tables = filter_tables(0, ranked)[:MAX_TABLE]
+
+     # Enable the model's thinking mode for questions rated as difficult.
+     dif = int(evaluate_difficulty(client, prompt))
+     if dif > 7:
+         print("difficulty is > 7 so we enable thinking mode")
+         use_thinking = True
+
+     sql = generate_sql(client, prompt, tables, use_thinking)
+     nb_try = 0
+     success = False
+     # Run the SQL, asking the LLM to self-correct after each failure.
+     while nb_try < retry and not success:
+         nb_try += 1
+         try:
+             print("trying to run the SQL request")
+             result = database.query(sql)
+             success = True
+         except Exception as e:
+             print(f"Error: {e}")
+             print("Trying to self-correct...")
+             error = f"{type(e).__name__} - {str(e)}"
+             # Thinking mode is enabled for the first correction attempts, then disabled.
+             if nb_try < retry - 2:
+                 sql = correct_sql(client, prompt, sql, tables, error, True)
+             else:
+                 sql = correct_sql(client, prompt, sql, tables, error, False)
+
+     print(sql)
+
+     if success:
+         print(result.to_markdown())
+         return generate_answer(client, sql, prompt, result.to_markdown(), use_thinking)
+
+ # db = Database("mysql://user:password@localhost:3306/Pokemon")
+ # db.connect()
+ # file = FileSource("./Wines.csv")
+ # file.connect()
+ # print(run_agent(file, "What is the quality of the wine with the least alcohol?"))
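
For context, a minimal end-to-end smoke test of run_agent could look like the sketch below (the endpoint URL, key, and Wines.csv file are assumptions; any OpenAI-compatible server works):

    # Hypothetical local test; set the env vars before importing agent.py,
    # since the OpenAI client is created at import time.
    import os
    os.environ["LLM_ENDPOINT"] = "http://localhost:8000/v1"   # assumed server URL
    os.environ["LLM_KEY"] = "super-secret-key-mcp-hackathon"  # assumed key

    from filesource import FileSource
    from agent import run_agent

    src = FileSource("./Wines.csv")  # assumed local CSV
    src.connect()
    for chunk in run_agent(src, "Which wine has the least alcohol?"):
        print(chunk, end="")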
app.py ADDED
@@ -0,0 +1,131 @@
+ import gradio as gr
+ from database import Database
+ from filesource import FileSource
+ from agent import run_agent
+ from services.utils import get_db_scheme_from_uri
+
+ source = None
+
+ def connect_to_file(file):
+     global source
+     try:
+         source = FileSource(file.name)
+         status = source.connect()
+         schema = source._pretify_schema()
+         status = "Connection successful!"
+     except Exception as e:
+         schema = ""
+         status = f"Error: {str(e)}"
+     return schema, status
+
+ def connect_to_database(db_url):
+     global source
+     try:
+         dialect = get_db_scheme_from_uri(db_url)
+         source = Database(db_url, dialect)
+         status = source.connect()
+         schema = source._pretify_schema()
+         status = "Connection successful!"
+     except Exception as e:
+         schema = ""
+         status = f"Error: {str(e)}"
+     return schema, status
+
+ # Add the user message to the chat history.
+ def user(user_message, chat_history):
+     chat_history.append({"role": "user", "content": user_message})
+     return "", chat_history
+
+ # Generate a bot response, streaming the answer chunk by chunk.
+ def bot(chat_history):
+     if source is None:
+         chat_history.append({"role": "assistant", "content": "Please connect to a database before asking a question."})
+         yield chat_history
+     else:
+         answer = run_agent(source, chat_history[-1]['content'])
+         chat_history.append({"role": "assistant", "content": ""})
+         for chunk in answer:
+             chat_history[-1]['content'] += chunk
+             yield chat_history
+
+ # Create the Gradio interface.
+ with gr.Blocks(theme=gr.themes.Default(), css="""
+ .gr-button { margin: 5px; border-radius: 16px; }
+ .gr-textbox, .gr-text-area, .gr-dropdown, .gr-json { border-radius: 8px; }
+ .gr-row { gap: 10px; }
+ .gr-tab { border-radius: 8px; }
+ .status-text { font-size: 0.9em; color: #555; }
+ .gr-json { max-height: 300px; overflow-y: auto; } /* Added scrolling for JSON */
+ """) as demo:
+     gr.Markdown(
+         """
+ # 🤖 MCP DB Answer
+ Your MCP server that lets you talk to any database.
+
+ Powered by Ibis, it supports PostgreSQL, SQLite, MySQL, MSSQL, ClickHouse, BigQuery, and many others.
+
+ It also supports .csv and .parquet files.
+ """,
+         elem_classes=["header"]
+     )
+
+     with gr.Column(scale=3):
+         with gr.Tabs():
+             with gr.TabItem("💬 Chat"):
+                 with gr.Group():
+                     main_chat_disp = gr.Chatbot(
+                         label=None, height=600,
+                         avatar_images=(None, "https://huggingface.co/spaces/Space-Share/bucket/resolve/main/images/pfp.webp"),
+                         show_copy_button=True, render_markdown=True, sanitize_html=True, type='messages'
+                     )
+                 with gr.Row(variant="compact"):
+                     user_msg_tb = gr.Textbox(
+                         show_label=False, placeholder="Talk with your data...",
+                         scale=7, lines=1, max_lines=3
+                     )
+                     send_btn = gr.Button("Send", variant="primary", scale=1, min_width=100)
+             with gr.TabItem("Config"):
+                 with gr.Row():
+                     # Left column for database configuration.
+                     with gr.Column(scale=1):
+                         gr.Markdown("## Database Configuration")
+                         # Textbox for entering the database URL.
+                         db_url_tb = gr.Textbox(
+                             show_label=True, label="Database URL", placeholder="Enter the URL to connect to the database..."
+                         )
+                         # Button to connect to the database.
+                         connect_btn = gr.Button("Connect", variant="primary")
+
+                         file_uploader = gr.File(
+                             label="Upload File", file_types=[".csv", ".parquet", ".xls", ".xlsx"]
+                         )
+                         # Button to load the uploaded file.
+                         load_btn = gr.Button("Load", variant="primary")
+
+                     # Right column for displaying the database schema and status message.
+                     with gr.Column(scale=3):
+                         gr.Markdown("## Database Schema")
+                         # Textarea to display the database schema.
+                         schema_ta = gr.TextArea(
+                             show_label=False, placeholder="Database schema will be displayed here...",
+                             lines=20, max_lines=50, interactive=False
+                         )
+                         # Textbox to display the status message.
+                         status_tb = gr.Textbox(
+                             show_label=False, placeholder="Status message will be displayed here...",
+                             lines=1, max_lines=1, interactive=False, elem_classes=["status-text"]
+                         )
+     connect_btn.click(fn=connect_to_database, inputs=db_url_tb, outputs=[schema_ta, status_tb])
+     load_btn.click(fn=connect_to_file, inputs=file_uploader, outputs=[schema_ta, status_tb])
+     send_btn.click(fn=user, inputs=[user_msg_tb, main_chat_disp], outputs=[user_msg_tb, main_chat_disp], queue=False).then(
+         fn=bot, inputs=main_chat_disp, outputs=main_chat_disp
+     )
+
+ if __name__ == "__main__":
+     demo.launch(mcp_server=True)
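
With mcp_server=True, Gradio also exposes the app's functions as MCP tools over SSE. A minimal client sketch using the `mcp` package pinned in requirements.txt (the localhost URL is an assumption; Gradio prints the actual MCP endpoint at startup):

    import asyncio
    from mcp import ClientSession
    from mcp.client.sse import sse_client

    async def main():
        # Assumed endpoint path for a locally running app.
        async with sse_client("http://localhost:7860/gradio_api/mcp/sse") as (read, write):
            async with ClientSession(read, write) as session:
                await session.initialize()
                tools = await session.list_tools()
                print([t.name for t in tools.tools])  # the exposed tools

    asyncio.run(main())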
database.py ADDED
@@ -0,0 +1,107 @@
+ import ibis
+ import sqlglot
+ from sqlglot import optimizer
+
+ class Database:
+
+     def __init__(self, connection_url, engine_dialect="mysql") -> None:
+         self._connect_url = connection_url
+         self.engine_dialect = engine_dialect
+         self._tables_docs = {}
+         self._table_exemple = {}
+
+     def connect(self):
+         try:
+             self._con = ibis.connect(self._connect_url)
+             return f"✅ Connection to {self._connect_url} OK!"
+         except Exception as e:
+             raise e
+
+     def _optimize_query(self, sql, schema):
+         # Qualify and simplify the parsed query against the known schema.
+         optimized_expression = optimizer.optimize(sql, schema=schema, dialect=self.engine_dialect)
+         optimized_sql = optimized_expression.sql(dialect=self.engine_dialect)
+         return optimized_sql
+
+     def _pretify_table(self, table, columns):
+         out = ""
+         if table in self._tables_docs:
+             out += f"## Documentation \n{self._tables_docs[table]}\n"
+         if table in self._table_exemple:
+             out += f"## Example \n{self._table_exemple[table]}"
+         out += f"Table ({table}) with {len(columns)} fields:\n"
+         for field in columns:
+             out += f"\t{field} of type: {columns[field]}\n"
+         return out
+
+     def add_table_documentation(self, table_name, documentation):
+         self._tables_docs[table_name] = documentation
+
+     def add_table_exemple(self, table_name, exemples):
+         self._table_exemple[table_name] = exemples
+
+     def get_tables_array(self):
+         schema = self._build_schema()
+         array = []
+         for table in schema:
+             array.append(self._pretify_table(table, schema[table]))
+         return array
+
+     def _pretify_schema(self):
+         out = ""
+         schema = self._build_schema()
+         for table in schema:
+             out += self._pretify_table(table, schema[table])
+             out += "\n"
+         return out
+
+     def _build_schema(self):
+         tables = self._con.list_tables()
+         schema = {}
+         for table_name in tables:
+             try:
+                 table_expr = self._con.table(table_name)
+                 table_schema = table_expr.schema()
+                 columns = {col: str(dtype) for col, dtype in table_schema.items()}
+                 schema[table_name] = columns
+             except Exception as e:
+                 print(f"Warning: Could not retrieve schema for table '{table_name}': {e}")
+         return schema
+
+     def query(self, sql_query):
+         schema = self._build_schema()
+         print(sql_query)
+         # A parse error is fatal: the SQL is not valid in the engine's dialect.
+         expression = sqlglot.parse_one(sql_query, read=self.engine_dialect)
+
+         # Try to optimize; fall back to the unoptimized query on failure.
+         try:
+             final_query = self._optimize_query(expression, schema)
+         except Exception:
+             final_query = expression.sql(dialect=self.engine_dialect)
+
+         expr = self._con.sql(final_query, dialect=self.engine_dialect)
+         return expr.execute()
+
+ # db = Database("mysql://user:password@localhost:3306/Pokemon")
+ # db.connect()
+ # schema = db._build_schema()
+ # db.add_table_documentation("Defense", "This is a super table")
+ # db.add_table_exemple("Defense", "caca")
+ # db.add_table_exemple("Joueur", "ezofkzrfp")
+ # for table in schema.keys():
+ #     print(db._pretify_table(table, schema[table]))
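
The _optimize_query step leans on sqlglot's schema-aware optimizer. A standalone sketch of what it does (the table and columns are made up for illustration):

    import sqlglot
    from sqlglot import optimizer

    schema = {"wines": {"id": "int", "alcohol": "double", "quality": "int"}}
    expr = sqlglot.parse_one("SELECT quality FROM wines ORDER BY alcohol LIMIT 1", read="duckdb")
    # optimize() qualifies column references and applies simplification rules.
    print(optimizer.optimize(expr, schema=schema, dialect="duckdb").sql(dialect="duckdb"))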
filesource.py ADDED
@@ -0,0 +1,126 @@
+ import ibis
+ import sqlglot
+ from sqlglot import optimizer
+ from services.utils import extract_filename
+
+ class FileSource:
+
+     def __init__(self, file_path, file_type="csv") -> None:
+         self.file_path = file_path
+         self.file_type = file_type.lower()
+         self._tables_docs = {}
+         self._table_exemple = {}
+         self.engine_dialect = "duckdb"
+
+     def connect(self):
+         try:
+             # Load the file into an in-memory DuckDB database, named after the file.
+             self._con = ibis.connect("duckdb://")
+             name = extract_filename(self.file_path)
+             table, ext = name.rsplit(".", 1)
+             if ext == "csv":
+                 self._table = self._con.read_csv(self.file_path, table_name=table)
+             elif ext == "parquet":
+                 self._table = self._con.read_parquet(self.file_path, table_name=table)
+             else:
+                 raise ValueError(f"Unsupported file type: {ext}")
+             self._schema = self._table.schema()
+             return f"✅ Connection to {name} OK!"
+         except Exception as e:
+             raise e
+
+     def _optimize_query(self, sql, schema):
+         optimized_expression = optimizer.optimize(sql, schema=schema, dialect=self.engine_dialect)
+         optimized_sql = optimized_expression.sql(dialect=self.engine_dialect)
+         return optimized_sql
+
+     def _pretify_table(self, table, columns):
+         out = ""
+         if table in self._tables_docs:
+             out += f"## Documentation \n{self._tables_docs[table]}\n"
+         if table in self._table_exemple:
+             out += f"## Example \n{self._table_exemple[table]}"
+         out += f"Table ({table}) with {len(columns)} fields:\n"
+         for field in columns:
+             out += f"\t{field} of type: {columns[field]}\n"
+         return out
+
+     def add_table_documentation(self, table_name, documentation):
+         self._tables_docs[table_name] = documentation
+
+     def add_table_exemple(self, table_name, exemples):
+         self._table_exemple[table_name] = exemples
+
+     def get_tables_array(self):
+         schema = self._build_schema()
+         array = []
+         for table in schema:
+             array.append(self._pretify_table(table, schema[table]))
+         return array
+
+     def _pretify_schema(self):
+         out = ""
+         schema = self._build_schema()
+         for table in schema:
+             out += self._pretify_table(table, schema[table])
+             out += "\n"
+         return out
+
+     def _build_schema(self):
+         tables = self._con.list_tables()
+         schema = {}
+         for table_name in tables:
+             try:
+                 table_expr = self._con.table(table_name)
+                 table_schema = table_expr.schema()
+                 columns = {col: str(dtype) for col, dtype in table_schema.items()}
+                 schema[table_name] = columns
+             except Exception as e:
+                 print(f"Warning: Could not retrieve schema for table '{table_name}': {e}")
+         return schema
+
+     def query(self, sql_query):
+         schema = self._build_schema()
+         print(sql_query)
+         # A parse error is fatal: the SQL is not valid DuckDB SQL.
+         expression = sqlglot.parse_one(sql_query, read=self.engine_dialect)
+
+         # Try to optimize; fall back to the unoptimized query on failure.
+         try:
+             final_query = self._optimize_query(expression, schema)
+         except Exception:
+             final_query = expression.sql(dialect=self.engine_dialect)
+
+         expr = self._con.sql(final_query, dialect=self.engine_dialect)
+         return expr.execute()
+
+ # file = FileSource("./Wines.csv")
+ # file.connect()
+ # schema = file._build_schema()
+ # for table in schema.keys():
+ #     print(file._pretify_table(table, schema[table]))
+
+ # res = file.query("SELECT * FROM Wines;")
+ # print(len(res))
modal/rerank_service.py ADDED
@@ -0,0 +1,106 @@
+ import modal
+
+ MINUTES = 60  # seconds
+ MODEL_REPO_ID = "Qwen/Qwen3-Reranker-4B"
+
+ rerank_image = (
+     modal.Image.debian_slim(python_version="3.12")
+     .pip_install(
+         "transformers==4.51.0",
+         "huggingface_hub[hf_transfer]",
+         "fastapi[standard]",
+         "torch"
+     )
+     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
+ )
+
+ hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
+
+ with rerank_image.imports():
+     import torch
+     from transformers import AutoTokenizer, AutoModelForCausalLM
+
+ app = modal.App("qwen3-rerank-service")
+
+ @app.function(image=rerank_image, volumes={
+     "/root/.cache/huggingface": hf_cache_vol
+ })
+ def download_model():
+     from huggingface_hub import snapshot_download
+     loc = snapshot_download(repo_id=MODEL_REPO_ID)
+     print(f"Saved model to {loc}")
+
+ @app.cls(image=rerank_image, gpu="A100-40GB", volumes={
+     "/root/.cache/huggingface": hf_cache_vol
+ })
+ class RerankerService:
+
+     @modal.enter()
+     def load_model(self):
+         self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID, padding_side='left')
+         self.model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()
+
+     @modal.method()
+     def rank(self, query, documents):
+         max_length = 8192
+         # The reranker is prompted to answer "yes"/"no"; the score is P("yes").
+         prefix = "<|im_start|>system\nJudge whether the Table will be useful to create an SQL request to answer the Query. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+         suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+         prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+         suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
+         token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+         token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+
+         def format_instruction(instruction, query, doc):
+             if instruction is None:
+                 instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+             return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
+
+         def process_inputs(pairs):
+             inputs = self.tokenizer(
+                 pairs, padding=False, truncation='longest_first',
+                 return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
+             )
+             for i, ele in enumerate(inputs['input_ids']):
+                 inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
+             inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
+             for key in inputs:
+                 inputs[key] = inputs[key].to(self.model.device)
+             return inputs
+
+         @torch.no_grad()
+         def compute_logits(inputs):
+             # Compare the logits of "yes" vs "no" at the final position and
+             # normalize them into a probability for "yes".
+             logits = self.model(**inputs).logits[:, -1, :]
+             true_vector = logits[:, token_true_id]
+             false_vector = logits[:, token_false_id]
+             batch_scores = torch.stack([false_vector, true_vector], dim=1)
+             batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+             return batch_scores[:, 1].exp().tolist()
+
+         instruction = "Given a user query find the useful tables in order to build an SQL request"
+         pairs = [format_instruction(instruction, query, doc) for doc in documents]
+
+         inputs = process_inputs(pairs)
+         scores = compute_logits(inputs)
+
+         return scores
+
+ @app.function(
+     image=modal.Image.debian_slim(python_version="3.12")
+     .pip_install("fastapi[standard]==0.115.4")
+ )
+ @modal.asgi_app(label="rerank-endpoint")
+ def fastapi_app():
+     from fastapi import FastAPI, Request
+     from fastapi.responses import JSONResponse
+
+     web_app = FastAPI()
+
+     @web_app.post("/predict")
+     async def predict(request: Request):
+         # Smoke test: rank a hardcoded query against two sample documents.
+         output_data = RerankerService().rank.remote(
+             "What is the capital of China?",
+             [
+                 "The capital of China is Beijing.",
+                 "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+             ],
+         )
+         return JSONResponse(content=output_data)
+
+     return web_app
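
A sketch of deploying this service and invoking it from another Python process (the app and class names come from this file; the workflow itself is an assumption):

    # modal deploy modal/rerank_service.py
    # modal run modal/rerank_service.py::download_model   # warm the weights cache
    import modal

    RerankerService = modal.Cls.from_name("qwen3-rerank-service", "RerankerService")
    scores = RerankerService().rank.remote(
        "Which wine has the least alcohol?",
        ["Table (Wines) with 2 fields: alcohol, quality",
         "Table (Players) with 2 fields: name, team"],
    )
    print(scores)  # one P("yes") relevance score per document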
modal/rerank_service_vllm.py ADDED
@@ -0,0 +1,136 @@
+ import modal
+ import logging
+
+ app = modal.App("qwen-reranker-vllm")
+ hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
+ vllm_cache_vol = modal.Volume.from_name("vllm-cache")
+ MINUTES = 60  # seconds
+
+ vllm_image = (
+     modal.Image.debian_slim(python_version="3.12")
+     .pip_install(
+         "vllm==0.8.5",
+         "transformers",
+         "torch",
+         "fastapi[all]",
+         "pydantic"
+     )
+     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
+ )
+
+ with vllm_image.imports():
+     from transformers import AutoTokenizer
+     from vllm import LLM, SamplingParams
+     from vllm.inputs.data import TokensPrompt
+     import torch
+     import math
+
+ @app.cls(image=vllm_image,
+          gpu="A100-40GB",
+          scaledown_window=15 * MINUTES,  # how long should we stay up with no requests?
+          timeout=10 * MINUTES,
+          volumes={
+              "/root/.cache/huggingface": hf_cache_vol,
+              "/root/.cache/vllm": vllm_cache_vol,
+          })
+ class Reranker:
+     @modal.enter()
+     def load_reranker(self):
+         logging.info("loading the reranker model")
+         self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
+         self.tokenizer.padding_side = "left"
+         self.tokenizer.pad_token = self.tokenizer.eos_token
+         self.model = LLM(
+             model="Qwen/Qwen3-Reranker-4B",
+             tensor_parallel_size=torch.cuda.device_count(),
+             max_model_len=10000,
+             enable_prefix_caching=True,
+             gpu_memory_utilization=0.8
+         )
+         self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+         self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
+         self.max_length = 8192
+         self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
+         self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
+         # Generate exactly one token, restricted to "yes"/"no".
+         self.sampling_params = SamplingParams(
+             temperature=0,
+             max_tokens=1,
+             logprobs=20,
+             allowed_token_ids=[self.true_token, self.false_token],
+         )
+
+     def format_instruction(self, instruction, query, doc):
+         return [
+             {"role": "system", "content": "Judge whether the Table will be useful to create an SQL request to answer the Query. Note that the answer can only be \"yes\" or \"no\""},
+             {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"}
+         ]
+
+     def process_inputs(self, pairs, instruction):
+         messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
+         messages = self.tokenizer.apply_chat_template(
+             messages, tokenize=True, add_generation_prompt=False, enable_thinking=False
+         )
+         messages = [ele[:self.max_length] + self.suffix_tokens for ele in messages]
+         messages = [TokensPrompt(prompt_token_ids=ele) for ele in messages]
+         return messages
+
+     def compute_logits(self, messages):
+         outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
+         scores = []
+         for i in range(len(outputs)):
+             final_logits = outputs[i].outputs[0].logprobs[-1]
+             # Fall back to a very low logprob when a token is missing from the top-k.
+             if self.true_token not in final_logits:
+                 true_logit = -10
+             else:
+                 true_logit = final_logits[self.true_token].logprob
+             if self.false_token not in final_logits:
+                 false_logit = -10
+             else:
+                 false_logit = final_logits[self.false_token].logprob
+             true_score = math.exp(true_logit)
+             false_score = math.exp(false_logit)
+             score = true_score / (true_score + false_score)
+             scores.append(score)
+         return scores
+
+     @modal.method()
+     def rerank(self, query, documents, task):
+         pairs = [(query, doc) for doc in documents]
+         inputs = self.process_inputs(pairs, task)
+         scores = self.compute_logits(inputs)
+         return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]
+
+ @app.function(
+     image=modal.Image.debian_slim(python_version="3.12")
+     .pip_install("fastapi[standard]==0.115.4", "pydantic")
+ )
+ @modal.asgi_app(label="rerank-endpoint")
+ def fastapi_app():
+     from pydantic import BaseModel
+     from fastapi import FastAPI
+     from fastapi.responses import JSONResponse
+     from typing import List
+
+     web_app = FastAPI()
+     reranker = Reranker()
+
+     class ScoringResult(BaseModel):
+         score: float
+         content: str
+
+     class RankingRequest(BaseModel):
+         task: str
+         query: str
+         documents: List[str]
+
+     @web_app.post("/rank", response_model=List[ScoringResult])
+     async def predict(payload: RankingRequest):
+         logging.info("calling the rank function")
+         output_data = reranker.rerank.remote(payload.query, payload.documents, payload.task)
+         return JSONResponse(content=output_data)
+
+     return web_app
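
Calling the deployed /rank endpoint over HTTP, mirroring the payload rank_tables() sends in services/ai.py (the URL shape is an assumption; `modal deploy` prints the real one):

    import requests

    payload = {
        "task": "Given a user query find the useful tables in order to build an SQL request",
        "query": "Which wine has the least alcohol?",
        "documents": ["Table (Wines) with 2 fields: alcohol, quality",
                      "Table (Players) with 2 fields: name, team"],
    }
    resp = requests.post("https://<workspace>--rerank-endpoint.modal.run/rank", json=payload)
    resp.raise_for_status()
    for item in sorted(resp.json(), key=lambda x: x["score"], reverse=True):
        print(round(item["score"], 3), item["content"])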
modal/vllm_service.py ADDED
@@ -0,0 +1,62 @@
+ import modal
+
+ vllm_image = (
+     modal.Image.debian_slim(python_version="3.12")
+     .pip_install(
+         "vllm==0.7.2",
+         "transformers==4.51.0",
+         "huggingface_hub[hf_transfer]",
+         "flashinfer-python==0.2.0.post2",
+         extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5",
+     )
+     .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster model transfers
+ )
+ vllm_image = vllm_image.env({"VLLM_USE_V1": "1"})
+
+ hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
+ vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
+
+ app = modal.App("example-vllm-openai-compatible")
+
+ N_GPU = 1  # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count
+ API_KEY = "super-secret-key-mcp-hackathon"  # API key for auth; for production use, replace with a modal.Secret
+
+ MINUTES = 60  # seconds
+ VLLM_PORT = 8000
+
+ MODEL_NAME = "Qwen/Qwen3-14B"
+
+ @app.function(
+     image=vllm_image,
+     gpu="A100-40GB",
+     scaledown_window=15 * MINUTES,  # how long should we stay up with no requests?
+     timeout=10 * MINUTES,  # how long should we wait for container start?
+     volumes={
+         "/root/.cache/huggingface": hf_cache_vol,
+         "/root/.cache/vllm": vllm_cache_vol,
+     },
+ )
+ @modal.concurrent(
+     max_inputs=10
+ )  # how many requests can one replica handle? tune carefully!
+ @modal.web_server(port=VLLM_PORT, startup_timeout=5 * MINUTES)
+ def serve():
+     import subprocess
+
+     cmd = [
+         "vllm",
+         "serve",
+         "--uvicorn-log-level=info",
+         MODEL_NAME,
+         "--host",
+         "0.0.0.0",
+         "--port",
+         str(VLLM_PORT),
+         "--api-key",
+         API_KEY,
+     ]
+
+     subprocess.Popen(" ".join(cmd), shell=True)
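
Once deployed, the agent's OpenAI client can be pointed at this server. A sketch (the URL is an assumption; `modal deploy` prints the actual one):

    from openai import OpenAI

    client = OpenAI(
        base_url="https://<workspace>--example-vllm-openai-compatible-serve.modal.run/v1",
        api_key="super-secret-key-mcp-hackathon",  # must match API_KEY above
    )
    resp = client.chat.completions.create(
        model="Qwen/Qwen3-14B",
        messages=[{"role": "user", "content": "Say hello. /no_think"}],
    )
    print(resp.choices[0].message.content)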
requirements.txt ADDED
@@ -0,0 +1,111 @@
+ aiofiles==24.1.0
+ aiohappyeyeballs==2.6.1
+ aiohttp==3.12.9
+ aiosignal==1.3.2
+ annotated-types==0.7.0
+ anyio==4.9.0
+ atpublic==6.0.1
+ attrs==25.3.0
+ certifi==2025.4.26
+ charset-normalizer==3.4.2
+ click==8.1.8
+ distro==1.9.0
+ duckdb==1.3.0
+ fastapi==0.115.12
+ ffmpy==0.6.0
+ filelock==3.18.0
+ frozenlist==1.6.2
+ fsspec==2025.5.1
+ gradio==5.33.0
+ gradio-client==1.10.2
+ groovy==0.1.2
+ grpclib==0.4.7
+ h11==0.16.0
+ h2==4.2.0
+ hf-xet==1.1.3
+ hpack==4.1.0
+ httpcore==1.0.9
+ httpx==0.28.1
+ httpx-sse==0.4.0
+ huggingface-hub==0.32.4
+ hyperframe==6.1.0
+ ibis-framework==10.5.0
+ idna==3.10
+ jinja2==3.1.6
+ jiter==0.10.0
+ markdown-it-py==3.0.0
+ markupsafe==3.0.2
+ mcp==1.9.0
+ mdurl==0.1.2
+ modal==1.0.3
+ mpmath==1.3.0
+ multidict==6.4.4
+ mysqlclient==2.2.7
+ networkx==3.5
+ numpy==2.2.6
+ nvidia-cublas-cu12==12.6.4.1
+ nvidia-cuda-cupti-cu12==12.6.80
+ nvidia-cuda-nvrtc-cu12==12.6.77
+ nvidia-cuda-runtime-cu12==12.6.77
+ nvidia-cudnn-cu12==9.5.1.17
+ nvidia-cufft-cu12==11.3.0.4
+ nvidia-cufile-cu12==1.11.1.6
+ nvidia-curand-cu12==10.3.7.77
+ nvidia-cusolver-cu12==11.7.1.2
+ nvidia-cusparse-cu12==12.5.4.2
+ nvidia-cusparselt-cu12==0.6.3
+ nvidia-nccl-cu12==2.26.2
+ nvidia-nvjitlink-cu12==12.6.85
+ nvidia-nvtx-cu12==12.6.77
+ openai==1.84.0
+ orjson==3.10.18
+ packaging==25.0
+ pandas==2.2.3
+ parsy==2.1
+ pillow==11.2.1
+ propcache==0.3.1
+ protobuf==6.31.1
+ pyarrow==20.0.0
+ pyarrow-hotfix==0.7
+ pydantic==2.11.5
+ pydantic-core==2.33.2
+ pydantic-settings==2.9.1
+ pydub==0.25.1
+ pygments==2.19.1
+ python-dateutil==2.9.0.post0
+ python-dotenv==1.1.0
+ python-multipart==0.0.20
+ pytz==2025.2
+ pyyaml==6.0.2
+ requests==2.32.3
+ rich==14.0.0
+ ruff==0.11.13
+ safehttpx==0.1.6
+ semantic-version==2.10.0
+ setuptools==80.9.0
+ shellingham==1.5.4
+ sigtools==4.0.1
+ six==1.17.0
+ sniffio==1.3.1
+ sqlglot==26.24.0
+ sse-starlette==2.3.6
+ starlette==0.46.2
+ sympy==1.14.0
+ synchronicity==0.9.13
+ tabulate==0.9.0
+ toml==0.10.2
+ tomlkit==0.13.3
+ toolz==1.0.0
+ tqdm==4.67.1
+ triton==3.3.1
+ typer==0.16.0
+ types-certifi==2021.10.8.3
+ types-toml==0.10.8.20240310
+ typing-extensions==4.14.0
+ typing-inspection==0.4.1
+ tzdata==2025.2
+ urllib3==2.4.0
+ uvicorn==0.34.3
+ watchfiles==1.0.5
+ websockets==15.0.1
+ yarl==1.20.0
services/__init__.py ADDED
@@ -0,0 +1,24 @@
+ import re
+
+ # Note: agent.py imports generate_sql from services.ai; this earlier variant
+ # expects the reranked table dicts ({"score": ..., "content": ...}).
+ def generate_sql(client, query, tables):
+     out = "## Database tables\n"
+     for table in tables:
+         out += table.get('content')
+
+     prompt = f"Generate an SQL query to answer this question {query} \n Based on this database information \n {out} /no_think"
+     print(prompt)
+     response = client.chat.completions.create(
+         model="Qwen/Qwen3-8B",  # model name to use
+         messages=[
+             {"role": "system", "content": "You are an expert in generating SQL queries based on a given schema. You will output the generated query in <sql> </sql> tags"},
+             {"role": "user", "content": prompt}
+         ],
+         temperature=0.7
+     )
+
+     # Extract the query from the <sql> tags in the response.
+     txt = response.choices[0].message.content
+     match = re.search(r"<sql>(.*?)</sql>", txt, re.DOTALL | re.IGNORECASE)
+     if match:
+         return match.group(1).strip()
+     return None
services/ai.py ADDED
@@ -0,0 +1,116 @@
 
+ import re
+ import requests
+ import os
+
+ def rank_tables(query, tables):
+     task = "Given a user query find the useful tables in order to build an SQL request"
+     payload = {
+         "task": task,
+         "query": query,
+         "documents": tables
+     }
+
+     response = requests.post(os.getenv('RERANK_ENDPOINT'), json=payload)
+
+     if response.status_code == 200:
+         results = response.json()
+         sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
+         return sorted_results
+     else:
+         raise Exception(f"Request failed: {response.status_code} - {response.text}")
+
+ def call_llm(client, system_prompt, user_prompt, model="Qwen/Qwen3-14B", temperature=0):
+     response = client.chat.completions.create(
+         model=model,
+         messages=[
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt}
+         ],
+         temperature=temperature
+     )
+     return response.choices[0].message.content
+
+ def call_llm_streaming(client, system_prompt, user_prompt,
+                        model="Qwen/Qwen3-14B", temperature=0):
+     # Start a streaming chat completion and yield each content delta.
+     stream = client.chat.completions.create(
+         model=model,
+         messages=[
+             {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt}
+         ],
+         temperature=temperature,
+         stream=True
+     )
+     out = ""
+     for chunk in stream:
+         delta = chunk.choices[0].delta.content
+         if delta:
+             out += delta
+             yield delta
+     return out
+
+ def format_thinking(thinking):
+     # Qwen3 convention: appending "/no_think" to the prompt disables thinking mode.
+     return "" if thinking else "/no_think"
+
+ def extract_tagged_content(text, tag):
+     pattern = fr"<{tag}>(.*?)</{tag}>"
+     match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+     return match.group(1).strip() if match else None
+
+ def generate_answer(client, sql, query, result, thinking=False):
+     prompt = (
+         f"The user asked this question: \n\"{query}\"\n"
+         f"The executed SQL was: \n\"{sql}\"\n"
+         f"This is the result: \n{result}\n"
+         f"Please give an answer to the user. "
+         f"Use a friendly tone and at the end suggest follow-up questions related to the first user question {format_thinking(thinking)}"
+     )
+     system = "You are an expert in providing responses to questions based on provided content"
+     return call_llm_streaming(client, system, prompt)
+
+ def evaluate_difficulty(client, query):
+     prompt = (
+         f"Output a grade between 0 and 10 on how difficult it is to generate an SQL query "
+         f"to answer this question:\n{query}\n/no_think"
+     )
+     system = (
+         "Your task is to evaluate the level of difficulty for generating an SQL query. "
+         "You will only output the difficulty level, which is between 0 and 10, output in <score></score> tags"
+     )
+     content = call_llm(client, system, prompt)
+     return extract_tagged_content(content, "score")
+
+ def generate_sql(client, query, tables, thinking=False):
+     schema_info = "## Database tables\n" + "\n".join(tables)
+     prompt = (
+         f"Generate an SQL query to answer this question: \"{query}\"\n"
+         f"Based on this database information:\n{schema_info} {format_thinking(thinking)}"
+     )
+     system = (
+         "You are an expert in generating SQL queries based on a given schema. "
+         "You will output the generated query in <sql></sql> tags. "
+         "Attention: you can only run one SQL query, so if you need multiple steps, you must use subqueries."
+     )
+     content = call_llm(client, system, prompt)
+     return extract_tagged_content(content, "sql")
+
+ def correct_sql(client, question, query, tables, error, thinking=True):
+     schema_info = "## Database tables\n" + "\n".join(tables)
+     prompt = (
+         f"To answer this question: \"{question}\", I tried to run this SQL query:\n{query}\n"
+         f"But I got this error:\n{error}\n"
+         f"Please take care of the provided schema and give a correct SQL to answer the question. "
+         f"Output the query in <sql></sql> tags.\n{schema_info} {format_thinking(thinking)}"
+     )
+     system = (
+         "You are an expert in generating SQL queries based on a given schema. "
+         "You will output the generated query in <sql></sql> tags."
+     )
+     content = call_llm(client, system, prompt)
+     return extract_tagged_content(content, "sql")
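
The helpers above all share one contract: the model wraps its answer in a tag, and extract_tagged_content pulls it out, returning None when the tag is missing (which callers such as agent.py's int(evaluate_difficulty(...)) should be prepared for). A tiny illustration:

    text = "<think>\n</think>\nHere you go: <sql>SELECT 1;</sql>"
    assert extract_tagged_content(text, "sql") == "SELECT 1;"
    assert extract_tagged_content(text, "score") is None  # absent tag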
services/utils.py ADDED
@@ -0,0 +1,45 @@
+ import os
+ from urllib.parse import urlparse
+
+ def filter_tables(threshold, sorted_results):
+     return [doc["content"] for doc in sorted_results if doc["score"] > threshold]
+
+ def get_db_scheme_from_uri(uri: str) -> str:
+     """
+     Given a SQLAlchemy-style connection URI, return its scheme name
+     (with any '+driver' suffix stripped).
+
+     Examples:
+         >>> get_db_scheme_from_uri("postgresql://user:pass@host/db")
+         'postgresql'
+
+         >>> get_db_scheme_from_uri("postgresql+psycopg2://user:pass@host/db")
+         'postgresql'
+
+         >>> get_db_scheme_from_uri("duckdb:///path/to/db.duckdb")
+         'duckdb'
+     """
+     parsed = urlparse(uri)
+     scheme = parsed.scheme
+     if not scheme:
+         raise ValueError(f"No scheme found in URI: {uri!r}")
+     # Strip any "+driver" suffix (e.g. "mysql+mysqldb")
+     return scheme.split("+", 1)[0]
+
+ def extract_filename(path_or_url):
+     """
+     Extract the file name from a local path or a URL.
+
+     Args:
+         path_or_url (str): The file path or URL.
+
+     Returns:
+         str: The extracted file name.
+     """
+     parsed = urlparse(path_or_url)
+     if parsed.scheme in ('http', 'https', 'ftp'):
+         return os.path.basename(parsed.path)
+     else:
+         return os.path.basename(path_or_url)