Upload 11 files
- agent.py +62 -0
- app.py +131 -0
- database.py +107 -0
- filesource.py +126 -0
- modal/rerank_service.py +106 -0
- modal/rerank_service_vllm.py +136 -0
- modal/vllm_service.py +62 -0
- requirements.txt +111 -0
- services/__init__.py +24 -0
- services/ai.py +116 -0
- services/utils.py +45 -0
agent.py
ADDED
@@ -0,0 +1,62 @@
+from services.ai import rank_tables, generate_sql, generate_answer, correct_sql, evaluate_difficulty
+from services.utils import filter_tables
+from openai import OpenAI
+from database import Database
+from filesource import FileSource
+import os
+
+MAX_TABLE = 3
+
+client = OpenAI(
+    base_url=os.getenv("LLM_ENDPOINT"),
+    api_key=os.getenv("LLM_KEY")
+)
+
+
+def run_agent(database, prompt):
+    retry = 5
+    tables = database.get_tables_array()
+    use_thinking = False
+
+    # Rerank and keep only the most relevant tables when the schema is large.
+    if len(tables) > MAX_TABLE:
+        print(f"using reranking because number of tables is greater than {MAX_TABLE}")
+        ranked = rank_tables(prompt, tables)
+        tables = filter_tables(0, ranked)[:MAX_TABLE]
+
+    # evaluate_difficulty may return None if the model omits the <score> tags.
+    dif = int(evaluate_difficulty(client, prompt) or 0)
+    if dif > 7:
+        print("difficulty is > 7 so we enable thinking mode")
+        use_thinking = True
+    sql = generate_sql(client, prompt, tables, use_thinking)
+    nb_try = 0
+    success = False
+    while nb_try < retry and not success:
+        nb_try += 1
+        try:
+            print("trying to run the SQL request")
+            result = database.query(sql)
+            success = True
+        except Exception as e:
+            print(f"Error: {e}")
+            print("Trying to self-correct...")
+            error = f"{type(e).__name__} - {str(e)}"
+            # Use thinking mode for the early correction attempts only.
+            if nb_try < retry - 2:
+                sql = correct_sql(client, prompt, sql, tables, error, True)
+            else:
+                sql = correct_sql(client, prompt, sql, tables, error, False)
+
+        print(sql)
+
+    if success:
+        print(result.to_markdown())
+        return generate_answer(client, sql, prompt, result.to_markdown(), use_thinking)
+    # Fall back to a single-chunk iterable so streaming callers still work.
+    return iter(["Sorry, I could not produce a working SQL query for this question."])
+
+
+# db = Database("mysql://user:password@localhost:3306/Pokemon")
+# db.connect()
+# file = FileSource("./Wines.csv")
+# file.connect()
+# print(run_agent(file, "What is the quality of the wine with the least alcohol?"))
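For reference, a minimal sketch of how run_agent's streamed answer could be consumed outside the Gradio UI, assuming LLM_ENDPOINT and LLM_KEY are set and a local Wines.csv exists (hypothetical file):

    from filesource import FileSource
    from agent import run_agent

    source = FileSource("./Wines.csv")   # hypothetical CSV file
    source.connect()
    # run_agent returns the generator from generate_answer, so stream it chunk by chunk.
    for chunk in run_agent(source, "Which wine has the least alcohol?"):
        print(chunk, end="", flush=True)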
app.py
ADDED
@@ -0,0 +1,131 @@
+import gradio as gr
+from database import Database
+from filesource import FileSource
+from agent import run_agent
+from services.utils import get_db_scheme_from_uri
+
+source = None
+
+
+def connect_to_file(file):
+    global source
+    try:
+        source = FileSource(file.name)
+        source.connect()
+        schema = source._pretify_schema()
+        status = "Connection successful!"
+    except Exception as e:
+        schema = ""
+        status = f"Error: {str(e)}"
+    return schema, status
+
+
+def connect_to_database(db_url):
+    global source
+    try:
+        dialect = get_db_scheme_from_uri(db_url)
+        source = Database(db_url, dialect)
+        source.connect()
+        schema = source._pretify_schema()
+        status = "Connection successful!"
+    except Exception as e:
+        schema = ""
+        status = f"Error: {str(e)}"
+    return schema, status
+
+
+# Add the user message to the chat history.
+def user(user_message, chat_history):
+    chat_history.append({"role": "user", "content": user_message})
+    return "", chat_history
+
+
+# Generate a bot response by streaming the agent's answer.
+def bot(chat_history):
+    if source is None:
+        chat_history.append({"role": "assistant", "content": "Please connect to a database before asking a question."})
+        yield chat_history
+    else:
+        answer = run_agent(source, chat_history[-1]["content"])
+        chat_history.append({"role": "assistant", "content": ""})
+        for chunk in answer:
+            chat_history[-1]["content"] += chunk
+            yield chat_history
+
+
+# Create the Gradio interface.
+with gr.Blocks(theme=gr.themes.Default(), css="""
+    .gr-button { margin: 5px; border-radius: 16px; }
+    .gr-textbox, .gr-text-area, .gr-dropdown, .gr-json { border-radius: 8px; }
+    .gr-row { gap: 10px; }
+    .gr-tab { border-radius: 8px; }
+    .status-text { font-size: 0.9em; color: #555; }
+    .gr-json { max-height: 300px; overflow-y: auto; } /* Added scrolling for JSON */
+""") as demo:
+    gr.Markdown(
+        """
+        # 🤖 MCP DB Answer
+        Your MCP server that lets you talk to any database.
+
+        Powered by Ibis, it supports PostgreSQL, SQLite, MySQL, MSSQL, ClickHouse, BigQuery, and many others.
+
+        It also supports .csv and .parquet files.
+        """,
+        elem_classes=["header"]
+    )
+
+    with gr.Column(scale=3):
+        with gr.Tabs():
+            with gr.TabItem("💬 Chat"):
+                with gr.Group():
+                    main_chat_disp = gr.Chatbot(
+                        label=None, height=600,
+                        avatar_images=(None, "https://huggingface.co/spaces/Space-Share/bucket/resolve/main/images/pfp.webp"),
+                        show_copy_button=True, render_markdown=True, sanitize_html=True, type='messages'
+                    )
+                with gr.Row(variant="compact"):
+                    user_msg_tb = gr.Textbox(
+                        show_label=False, placeholder="Talk with your data...",
+                        scale=7, lines=1, max_lines=3
+                    )
+                    send_btn = gr.Button("Send", variant="primary", scale=1, min_width=100)
+            with gr.TabItem("Config"):
+                with gr.Row():
+                    # Left column for database configuration.
+                    with gr.Column(scale=1):
+                        gr.Markdown("## Database Configuration")
+                        # Textbox for entering the database URL.
+                        db_url_tb = gr.Textbox(
+                            show_label=True, label="Database URL", placeholder="Enter the URL to connect to the database..."
+                        )
+                        # Button to connect to the database.
+                        connect_btn = gr.Button("Connect", variant="primary")
+
+                        file_uploader = gr.File(
+                            label="Upload File", file_types=[".csv", ".parquet", ".xls", ".xlsx"]
+                        )
+                        # Button to load the uploaded file.
+                        load_btn = gr.Button("Load", variant="primary")
+
+                    # Right column for displaying the database schema and status message.
+                    with gr.Column(scale=3):
+                        gr.Markdown("## Database Schema")
+                        # Textarea to display the database schema.
+                        schema_ta = gr.TextArea(
+                            show_label=False, placeholder="Database schema will be displayed here...",
+                            lines=20, max_lines=50, interactive=False
+                        )
+                        # Textbox to display the status message.
+                        status_tb = gr.Textbox(
+                            show_label=False, placeholder="Status message will be displayed here...",
+                            lines=1, max_lines=1, interactive=False, elem_classes=["status-text"]
+                        )
+    connect_btn.click(fn=connect_to_database, inputs=db_url_tb, outputs=[schema_ta, status_tb])
+    load_btn.click(fn=connect_to_file, inputs=file_uploader, outputs=[schema_ta, status_tb])
+    send_btn.click(fn=user, inputs=[user_msg_tb, main_chat_disp], outputs=[user_msg_tb, main_chat_disp], queue=False).then(
+        fn=bot, inputs=main_chat_disp, outputs=main_chat_disp
+    )
+
+if __name__ == "__main__":
+    demo.launch(mcp_server=True)
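The connection helpers can also be exercised without launching the UI; a minimal sketch, assuming a local SQLite database file (hypothetical path):

    from app import connect_to_database

    # Any Ibis-supported connection URL works here.
    schema, status = connect_to_database("sqlite:///example.db")
    print(status)
    print(schema)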
database.py
ADDED
@@ -0,0 +1,107 @@
+import ibis
+import sqlglot
+from sqlglot import optimizer
+
+
+class Database:
+
+    def __init__(self, connection_url, engine_dialect="mysql") -> None:
+        self._connect_url = connection_url
+        self.engine_dialect = engine_dialect
+        self._tables_docs = {}
+        self._table_exemple = {}
+
+    def connect(self):
+        self._con = ibis.connect(self._connect_url)
+        return f"✅ Connection to {self._connect_url} OK!"
+
+    def _optimize_query(self, sql, schema):
+        optimized_expression = optimizer.optimize(sql, schema=schema, dialect=self.engine_dialect)
+        return optimized_expression.sql(dialect=self.engine_dialect)
+
+    def _pretify_table(self, table, columns):
+        out = ""
+        if table in self._tables_docs:
+            out += f"## Documentation \n{self._tables_docs[table]}\n"
+        if table in self._table_exemple:
+            out += f"## Exemple \n{self._table_exemple[table]}\n"
+        out += f"Table ({table}) with {len(columns)} fields : \n"
+        for field in columns:
+            out += f"\t{field} of type : {columns[field]}\n"
+        return out
+
+    def add_table_documentation(self, table_name, documentation):
+        self._tables_docs[table_name] = documentation
+
+    def add_table_exemple(self, table_name, exemples):
+        self._table_exemple[table_name] = exemples
+
+    def get_tables_array(self):
+        schema = self._build_schema()
+        return [self._pretify_table(table, schema[table]) for table in schema]
+
+    def _pretify_schema(self):
+        out = ""
+        schema = self._build_schema()
+        for table in schema:
+            out += self._pretify_table(table, schema[table])
+            out += "\n"
+        return out
+
+    def _build_schema(self):
+        tables = self._con.list_tables()
+        schema = {}
+        for table_name in tables:
+            try:
+                table_expr = self._con.table(table_name)
+                table_schema = table_expr.schema()
+                schema[table_name] = {col: str(dtype) for col, dtype in table_schema.items()}
+            except Exception as e:
+                print(f"Warning: Could not retrieve schema for table '{table_name}': {e}")
+        return schema
+
+    def query(self, sql_query):
+        schema = self._build_schema()
+        print(sql_query)
+        # Parse first so syntax errors surface before execution.
+        expression = sqlglot.parse_one(sql_query, read=self.engine_dialect)
+
+        # Optimization is best-effort: fall back to the unoptimized SQL on failure.
+        try:
+            final_query = self._optimize_query(expression, schema)
+        except Exception:
+            final_query = expression.sql(dialect=self.engine_dialect)
+
+        expr = self._con.sql(final_query, dialect=self.engine_dialect)
+        return expr.execute()
+
+
+# db = Database("mysql://user:password@localhost:3306/Pokemon")
+# db.connect()
+# schema = db._build_schema()
+# db.add_table_documentation("Defense", "This is a super table")
+# db.add_table_exemple("Defense", "caca")
+# db.add_table_exemple("Joueur", "ezofkzrfp")
+# for table in schema.keys():
+#     print(db._pretify_table(table, schema[table]))
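The parse-then-optimize-with-fallback pattern used in query(), shown standalone with a hypothetical schema:

    import sqlglot
    from sqlglot import optimizer

    schema = {"Wines": {"quality": "INT", "alcohol": "DOUBLE"}}  # hypothetical schema
    expression = sqlglot.parse_one("SELECT quality FROM Wines ORDER BY alcohol LIMIT 1", read="duckdb")
    try:
        # Qualify and simplify the query against the schema when possible.
        final_query = optimizer.optimize(expression, schema=schema, dialect="duckdb").sql(dialect="duckdb")
    except Exception:
        # Optimization is best-effort; fall back to the unoptimized SQL.
        final_query = expression.sql(dialect="duckdb")
    print(final_query)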
filesource.py
ADDED
@@ -0,0 +1,126 @@
+import ibis
+import sqlglot
+from sqlglot import optimizer
+from services.utils import extract_filename
+
+
+class FileSource:
+
+    def __init__(self, file_path, file_type="csv") -> None:
+        self.file_path = file_path
+        self.file_type = file_type.lower()
+        self._tables_docs = {}
+        self._table_exemple = {}
+        self.engine_dialect = "duckdb"
+
+    def connect(self):
+        self._con = ibis.connect("duckdb://")
+        name = extract_filename(self.file_path)
+        # Split on the last dot so file names containing dots still work.
+        table, _, ext = name.rpartition(".")
+        if ext == "csv":
+            self._table = self._con.read_csv(self.file_path, table_name=table)
+        elif ext == "parquet":
+            self._table = self._con.read_parquet(self.file_path, table_name=table)
+        else:
+            raise ValueError(f"Unsupported file extension: {ext!r}")
+        self._schema = self._table.schema()
+        return f"✅ Connection to {name} OK!"
+
+    def _optimize_query(self, sql, schema):
+        optimized_expression = optimizer.optimize(sql, schema=schema, dialect=self.engine_dialect)
+        return optimized_expression.sql(dialect=self.engine_dialect)
+
+    def _pretify_table(self, table, columns):
+        out = ""
+        if table in self._tables_docs:
+            out += f"## Documentation \n{self._tables_docs[table]}\n"
+        if table in self._table_exemple:
+            out += f"## Exemple \n{self._table_exemple[table]}\n"
+        out += f"Table ({table}) with {len(columns)} fields : \n"
+        for field in columns:
+            out += f"\t{field} of type : {columns[field]}\n"
+        return out
+
+    def add_table_documentation(self, table_name, documentation):
+        self._tables_docs[table_name] = documentation
+
+    def add_table_exemple(self, table_name, exemples):
+        self._table_exemple[table_name] = exemples
+
+    def get_tables_array(self):
+        schema = self._build_schema()
+        return [self._pretify_table(table, schema[table]) for table in schema]
+
+    def _pretify_schema(self):
+        out = ""
+        schema = self._build_schema()
+        for table in schema:
+            out += self._pretify_table(table, schema[table])
+            out += "\n"
+        return out
+
+    def _build_schema(self):
+        tables = self._con.list_tables()
+        schema = {}
+        for table_name in tables:
+            try:
+                table_expr = self._con.table(table_name)
+                table_schema = table_expr.schema()
+                schema[table_name] = {col: str(dtype) for col, dtype in table_schema.items()}
+            except Exception as e:
+                print(f"Warning: Could not retrieve schema for table '{table_name}': {e}")
+        return schema
+
+    def query(self, sql_query):
+        schema = self._build_schema()
+        print(sql_query)
+        # Parse first so syntax errors surface before execution.
+        expression = sqlglot.parse_one(sql_query, read=self.engine_dialect)
+
+        # Optimization is best-effort: fall back to the unoptimized SQL on failure.
+        try:
+            final_query = self._optimize_query(expression, schema)
+        except Exception:
+            final_query = expression.sql(dialect=self.engine_dialect)
+
+        expr = self._con.sql(final_query, dialect=self.engine_dialect)
+        return expr.execute()
+
+
+# file = FileSource("./Wines.csv")
+# file.connect()
+# schema = file._build_schema()
+# for table in schema.keys():
+#     print(file._pretify_table(table, schema[table]))
+
+# res = file.query("SELECT * FROM Wines;")
+# print(len(res))
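Under the hood, FileSource simply registers the file as a table in an in-memory DuckDB connection via Ibis; a minimal sketch (hypothetical CSV path):

    import ibis

    con = ibis.connect("duckdb://")  # in-memory DuckDB
    table = con.read_csv("./Wines.csv", table_name="Wines")  # hypothetical file
    print(table.schema())  # column names and types, as used by _build_schema()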
modal/rerank_service.py
ADDED
@@ -0,0 +1,106 @@
+import modal
+
+MINUTES = 60  # seconds
+MODEL_REPO_ID = "Qwen/Qwen3-Reranker-4B"
+
+rerank_image = (
+    modal.Image.debian_slim(python_version="3.12")
+    .pip_install(
+        "transformers==4.51.0",
+        "huggingface_hub[hf_transfer]",
+        "fastapi[standard]",
+        "torch"
+    )
+    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
+)
+
+hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
+
+with rerank_image.imports():
+    import torch
+    from transformers import AutoTokenizer, AutoModelForCausalLM
+
+app = modal.App("qwen3-rerank-service")
+
+
+@app.function(image=rerank_image, volumes={
+    "/root/.cache/huggingface": hf_cache_vol
+})
+def download_model():
+    from huggingface_hub import snapshot_download
+    loc = snapshot_download(repo_id=MODEL_REPO_ID)
+    print(f"Saved model to {loc}")
+
+
+@app.cls(image=rerank_image, gpu="A100-40GB", volumes={
+    "/root/.cache/huggingface": hf_cache_vol
+})
+class RerankerService:
+
+    @modal.enter()
+    def load_model(self):
+        self.tokenizer = AutoTokenizer.from_pretrained(MODEL_REPO_ID, padding_side='left')
+        self.model = AutoModelForCausalLM.from_pretrained(MODEL_REPO_ID, torch_dtype=torch.float16, attn_implementation="flash_attention_2").cuda().eval()
+
+    @modal.method()
+    def rank(self, query, documents):
+        max_length = 8192
+        prefix = "<|im_start|>system\nJudge whether the Table will be useful to create an SQL request to answer the Query. Note that the answer can only be \"yes\" or \"no\".<|im_end|>\n<|im_start|>user\n"
+        suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        prefix_tokens = self.tokenizer.encode(prefix, add_special_tokens=False)
+        suffix_tokens = self.tokenizer.encode(suffix, add_special_tokens=False)
+        token_false_id = self.tokenizer.convert_tokens_to_ids("no")
+        token_true_id = self.tokenizer.convert_tokens_to_ids("yes")
+
+        def format_instruction(instruction, query, doc):
+            if instruction is None:
+                instruction = 'Given a web search query, retrieve relevant passages that answer the query'
+            return f"<Instruct>: {instruction}\n<Query>: {query}\n<Document>: {doc}"
+
+        def process_inputs(pairs):
+            inputs = self.tokenizer(
+                pairs, padding=False, truncation='longest_first',
+                return_attention_mask=False, max_length=max_length - len(prefix_tokens) - len(suffix_tokens)
+            )
+            for i, ele in enumerate(inputs['input_ids']):
+                inputs['input_ids'][i] = prefix_tokens + ele + suffix_tokens
+            inputs = self.tokenizer.pad(inputs, padding=True, return_tensors="pt", max_length=max_length)
+            for key in inputs:
+                inputs[key] = inputs[key].to(self.model.device)
+            return inputs
+
+        @torch.no_grad()
+        def compute_logits(inputs):
+            logits = self.model(**inputs).logits[:, -1, :]
+            true_vector = logits[:, token_true_id]
+            false_vector = logits[:, token_false_id]
+            batch_scores = torch.stack([false_vector, true_vector], dim=1)
+            batch_scores = torch.nn.functional.log_softmax(batch_scores, dim=1)
+            return batch_scores[:, 1].exp().tolist()
+
+        instruction = "Given a user query, find the useful tables in order to build an SQL request"
+        pairs = [format_instruction(instruction, query, doc) for doc in documents]
+
+        inputs = process_inputs(pairs)
+        scores = compute_logits(inputs)
+
+        return scores
+
+
+@app.function(
+    image=modal.Image.debian_slim(python_version="3.12")
+    .pip_install("fastapi[standard]==0.115.4")
+)
+@modal.asgi_app(label="rerank-endpoint")
+def fastapi_app():
+    from fastapi import FastAPI, Request
+    from fastapi.responses import JSONResponse
+
+    web_app = FastAPI()
+
+    @web_app.post("/predict")
+    async def predict(request: Request):
+        # Placeholder endpoint: the request body is read but ignored, and a
+        # fixed example query/document pair is reranked instead.
+        body = await request.body()
+        output_data = RerankerService().rank.remote(
+            "What is the capital of China?",
+            [
+                "The capital of China is Beijing.",
+                "Gravity is a force that attracts two bodies towards each other. It gives weight to physical objects and is responsible for the movement of planets around the sun.",
+            ],
+        )
+        return JSONResponse(content=output_data)
+
+    return web_app
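The relevance score produced by compute_logits is just the probability of "yes" under a two-way softmax over the final-position logits of the "yes" and "no" tokens; a scalar sketch of the same arithmetic:

    import math

    def yes_probability(yes_logit: float, no_logit: float) -> float:
        # Equivalent to log_softmax over [no, yes] followed by exp on the "yes" entry.
        return math.exp(yes_logit) / (math.exp(yes_logit) + math.exp(no_logit))

    print(yes_probability(2.0, -1.0))  # ~0.95: the model leans strongly toward "yes"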
modal/rerank_service_vllm.py
ADDED
@@ -0,0 +1,136 @@
+import modal
+import logging
+
+app = modal.App("qwen-reranker-vllm")
+hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
+vllm_cache_vol = modal.Volume.from_name("vllm-cache")
+MINUTES = 60  # seconds
+
+vllm_image = (
+    modal.Image.debian_slim(python_version="3.12")
+    .pip_install(
+        "vllm==0.8.5",
+        "transformers",
+        "torch",
+        "fastapi[all]",
+        "pydantic"
+    )
+    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})
+)
+
+
+with vllm_image.imports():
+    from transformers import AutoTokenizer
+    from vllm import LLM, SamplingParams
+    from vllm.inputs.data import TokensPrompt
+    import torch
+    import math
+
+
+@app.cls(image=vllm_image,
+         gpu="A100-40GB",
+         scaledown_window=15 * MINUTES,  # how long should we stay up with no requests?
+         timeout=10 * MINUTES,
+         volumes={
+             "/root/.cache/huggingface": hf_cache_vol,
+             "/root/.cache/vllm": vllm_cache_vol,
+         })
+class Reranker:
+    @modal.enter()
+    def load_reranker(self):
+        logging.info("loading the reranker model")
+        self.tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-Reranker-4B")
+        self.tokenizer.padding_side = "left"
+        self.tokenizer.pad_token = self.tokenizer.eos_token
+        self.model = LLM(
+            model="Qwen/Qwen3-Reranker-4B",
+            tensor_parallel_size=torch.cuda.device_count(),
+            max_model_len=10000,
+            enable_prefix_caching=True,
+            gpu_memory_utilization=0.8
+        )
+        self.suffix = "<|im_end|>\n<|im_start|>assistant\n<think>\n\n</think>\n\n"
+        self.suffix_tokens = self.tokenizer.encode(self.suffix, add_special_tokens=False)
+        self.max_length = 8192
+        self.true_token = self.tokenizer("yes", add_special_tokens=False).input_ids[0]
+        self.false_token = self.tokenizer("no", add_special_tokens=False).input_ids[0]
+        self.sampling_params = SamplingParams(
+            temperature=0,
+            max_tokens=1,
+            logprobs=20,
+            allowed_token_ids=[self.true_token, self.false_token],
+        )
+
+    def format_instruction(self, instruction, query, doc):
+        return [
+            {"role": "system", "content": "Judge whether the Table will be useful to create an SQL request to answer the Query. Note that the answer can only be \"yes\" or \"no\"."},
+            {"role": "user", "content": f"<Instruct>: {instruction}\n\n<Query>: {query}\n\n<Document>: {doc}"}
+        ]
+
+    def process_inputs(self, pairs, instruction):
+        messages = [self.format_instruction(instruction, query, doc) for query, doc in pairs]
+        messages = self.tokenizer.apply_chat_template(
+            messages, tokenize=True, add_generation_prompt=False, enable_thinking=False
+        )
+        messages = [ele[:self.max_length] + self.suffix_tokens for ele in messages]
+        messages = [TokensPrompt(prompt_token_ids=ele) for ele in messages]
+        return messages
+
+    def compute_logits(self, messages):
+        outputs = self.model.generate(messages, self.sampling_params, use_tqdm=False)
+        scores = []
+        for i in range(len(outputs)):
+            final_logits = outputs[i].outputs[0].logprobs[-1]
+            # Fall back to a very low logprob when a token is absent from the top-k.
+            if self.true_token not in final_logits:
+                true_logit = -10
+            else:
+                true_logit = final_logits[self.true_token].logprob
+            if self.false_token not in final_logits:
+                false_logit = -10
+            else:
+                false_logit = final_logits[self.false_token].logprob
+            true_score = math.exp(true_logit)
+            false_score = math.exp(false_logit)
+            score = true_score / (true_score + false_score)
+            scores.append(score)
+        return scores
+
+    @modal.method()
+    def rerank(self, query, documents, task):
+        pairs = [(query, doc) for doc in documents]
+        inputs = self.process_inputs(pairs, task)
+        scores = self.compute_logits(inputs)
+        return [{"score": float(score), "content": doc} for score, doc in zip(scores, documents)]
+
+
+# Note: this uses the same "rerank-endpoint" label as modal/rerank_service.py;
+# deploy one of the two rerank services at a time.
+@app.function(
+    image=modal.Image.debian_slim(python_version="3.12")
+    .pip_install("fastapi[standard]==0.115.4", "pydantic")
+)
+@modal.asgi_app(label="rerank-endpoint")
+def fastapi_app():
+    from pydantic import BaseModel
+    from fastapi import FastAPI
+    from fastapi.responses import JSONResponse
+    from typing import List
+
+    web_app = FastAPI()
+    reranker = Reranker()
+
+    class ScoringResult(BaseModel):
+        score: float
+        content: str
+
+    class RankingRequest(BaseModel):
+        task: str
+        query: str
+        documents: List[str]
+
+    @web_app.post("/rank", response_model=List[ScoringResult])
+    async def predict(payload: RankingRequest):
+        logging.info("call the rank function")
+        output_data = reranker.rerank.remote(payload.query, payload.documents, payload.task)
+        return JSONResponse(content=output_data)
+
+    return web_app
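Once deployed, the /rank endpoint accepts the RankingRequest schema above; a minimal client sketch (the URL is a placeholder for your Modal deployment):

    import requests

    payload = {
        "task": "Given a user query find the useful tables in order to build an SQL request",
        "query": "Which wine has the least alcohol?",
        "documents": ["Table (Wines) with 2 fields: quality, alcohol"],
    }
    resp = requests.post("https://<your-workspace>--rerank-endpoint.modal.run/rank", json=payload)
    print(resp.json())  # [{"score": ..., "content": ...}, ...]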
modal/vllm_service.py
ADDED
@@ -0,0 +1,62 @@
+import modal
+
+vllm_image = (
+    modal.Image.debian_slim(python_version="3.12")
+    .pip_install(
+        "vllm==0.7.2",
+        "transformers==4.51.0",
+        "huggingface_hub[hf_transfer]",
+        "flashinfer-python==0.2.0.post2",
+        extra_index_url="https://flashinfer.ai/whl/cu124/torch2.5",
+    )
+    .env({"HF_HUB_ENABLE_HF_TRANSFER": "1"})  # faster model transfers
+)
+vllm_image = vllm_image.env({"VLLM_USE_V1": "1"})
+
+hf_cache_vol = modal.Volume.from_name("mcp-datascientist-model-weights-vol")
+vllm_cache_vol = modal.Volume.from_name("vllm-cache", create_if_missing=True)
+
+app = modal.App("example-vllm-openai-compatible")
+
+N_GPU = 1  # tip: for best results, first upgrade to more powerful GPUs, and only then increase GPU count
+API_KEY = "super-secret-key-mcp-hackathon"  # API key for auth; for production use, replace with a modal.Secret
+
+MINUTES = 60  # seconds
+VLLM_PORT = 8000
+
+MODEL_NAME = "Qwen/Qwen3-14B"
+
+
+@app.function(
+    image=vllm_image,
+    gpu="A100-40GB",
+    scaledown_window=15 * MINUTES,  # how long should we stay up with no requests?
+    timeout=10 * MINUTES,  # how long should we wait for container start?
+    volumes={
+        "/root/.cache/huggingface": hf_cache_vol,
+        "/root/.cache/vllm": vllm_cache_vol,
+    },
+)
+@modal.concurrent(
+    max_inputs=10
+)  # how many requests can one replica handle? tune carefully!
+@modal.web_server(port=VLLM_PORT, startup_timeout=5 * MINUTES)
+def serve():
+    import subprocess
+
+    cmd = [
+        "vllm",
+        "serve",
+        "--uvicorn-log-level=info",
+        MODEL_NAME,
+        "--host",
+        "0.0.0.0",
+        "--port",
+        str(VLLM_PORT),
+        "--api-key",
+        API_KEY,
+    ]
+
+    subprocess.Popen(" ".join(cmd), shell=True)
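Because vLLM serves an OpenAI-compatible API, the deployed server can be queried with the standard openai client, which is how agent.py consumes it through LLM_ENDPOINT / LLM_KEY; a sketch with a placeholder URL:

    from openai import OpenAI

    client = OpenAI(
        base_url="https://<your-workspace>--serve.modal.run/v1",  # placeholder deployment URL
        api_key="super-secret-key-mcp-hackathon",                 # the API_KEY defined above
    )
    resp = client.chat.completions.create(
        model="Qwen/Qwen3-14B",
        messages=[{"role": "user", "content": "Say hello"}],
    )
    print(resp.choices[0].message.content)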
requirements.txt
ADDED
@@ -0,0 +1,111 @@
+aiofiles==24.1.0
+aiohappyeyeballs==2.6.1
+aiohttp==3.12.9
+aiosignal==1.3.2
+annotated-types==0.7.0
+anyio==4.9.0
+atpublic==6.0.1
+attrs==25.3.0
+certifi==2025.4.26
+charset-normalizer==3.4.2
+click==8.1.8
+distro==1.9.0
+duckdb==1.3.0
+fastapi==0.115.12
+ffmpy==0.6.0
+filelock==3.18.0
+frozenlist==1.6.2
+fsspec==2025.5.1
+gradio==5.33.0
+gradio-client==1.10.2
+groovy==0.1.2
+grpclib==0.4.7
+h11==0.16.0
+h2==4.2.0
+hf-xet==1.1.3
+hpack==4.1.0
+httpcore==1.0.9
+httpx==0.28.1
+httpx-sse==0.4.0
+huggingface-hub==0.32.4
+hyperframe==6.1.0
+ibis-framework==10.5.0
+idna==3.10
+jinja2==3.1.6
+jiter==0.10.0
+markdown-it-py==3.0.0
+markupsafe==3.0.2
+mcp==1.9.0
+mdurl==0.1.2
+modal==1.0.3
+mpmath==1.3.0
+multidict==6.4.4
+mysqlclient==2.2.7
+networkx==3.5
+numpy==2.2.6
+nvidia-cublas-cu12==12.6.4.1
+nvidia-cuda-cupti-cu12==12.6.80
+nvidia-cuda-nvrtc-cu12==12.6.77
+nvidia-cuda-runtime-cu12==12.6.77
+nvidia-cudnn-cu12==9.5.1.17
+nvidia-cufft-cu12==11.3.0.4
+nvidia-cufile-cu12==1.11.1.6
+nvidia-curand-cu12==10.3.7.77
+nvidia-cusolver-cu12==11.7.1.2
+nvidia-cusparse-cu12==12.5.4.2
+nvidia-cusparselt-cu12==0.6.3
+nvidia-nccl-cu12==2.26.2
+nvidia-nvjitlink-cu12==12.6.85
+nvidia-nvtx-cu12==12.6.77
+openai==1.84.0
+orjson==3.10.18
+packaging==25.0
+pandas==2.2.3
+parsy==2.1
+pillow==11.2.1
+propcache==0.3.1
+protobuf==6.31.1
+pyarrow==20.0.0
+pyarrow-hotfix==0.7
+pydantic==2.11.5
+pydantic-core==2.33.2
+pydantic-settings==2.9.1
+pydub==0.25.1
+pygments==2.19.1
+python-dateutil==2.9.0.post0
+python-dotenv==1.1.0
+python-multipart==0.0.20
+pytz==2025.2
+pyyaml==6.0.2
+requests==2.32.3
+rich==14.0.0
+ruff==0.11.13
+safehttpx==0.1.6
+semantic-version==2.10.0
+setuptools==80.9.0
+shellingham==1.5.4
+sigtools==4.0.1
+six==1.17.0
+sniffio==1.3.1
+sqlglot==26.24.0
+sse-starlette==2.3.6
+starlette==0.46.2
+sympy==1.14.0
+synchronicity==0.9.13
+tabulate==0.9.0
+toml==0.10.2
+tomlkit==0.13.3
+toolz==1.0.0
+tqdm==4.67.1
+triton==3.3.1
+typer==0.16.0
+types-certifi==2021.10.8.3
+types-toml==0.10.8.20240310
+typing-extensions==4.14.0
+typing-inspection==0.4.1
+tzdata==2025.2
+urllib3==2.4.0
+uvicorn==0.34.3
+watchfiles==1.0.5
+websockets==15.0.1
+yarl==1.20.0
services/__init__.py
ADDED
@@ -0,0 +1,24 @@
+import re
+
+
+def generate_sql(client, query, tables):
+    out = "## Database tables\n"
+    for table in tables:
+        out += table.get('content')
+
+    prompt = f"Generate an SQL query to answer this question {query} \n Based on this database information \n {out} /no_think"
+    print(prompt)
+    response = client.chat.completions.create(
+        model="Qwen/Qwen3-8B",  # model name to use
+        messages=[
+            {"role": "system", "content": "You are an expert in generating SQL queries based on a given schema. You will output the generated query in <sql> </sql> tags"},
+            {"role": "user", "content": prompt}
+        ],
+        temperature=0.7
+    )
+
+    # Extract the query from the <sql> tags in the model output.
+    txt = response.choices[0].message.content
+    match = re.search(r"<sql>(.*?)</sql>", txt, re.DOTALL | re.IGNORECASE)
+    if match:
+        return match.group(1).strip()
+    return None
services/ai.py
ADDED
@@ -0,0 +1,116 @@
+import re
+import requests
+import os
+
+
+def rank_tables(query, tables):
+    task = "Given a user query find the useful tables in order to build an SQL request"
+    payload = {
+        "task": task,
+        "query": query,
+        "documents": tables
+    }
+
+    response = requests.post(os.getenv('RERANK_ENDPOINT'), json=payload)
+
+    if response.status_code == 200:
+        results = response.json()
+        sorted_results = sorted(results, key=lambda x: x["score"], reverse=True)
+        return sorted_results
+    else:
+        raise Exception(f"Request failed: {response.status_code} - {response.text}")
+
+
+def call_llm(client, system_prompt, user_prompt, model="Qwen/Qwen3-14B", temperature=0):
+    response = client.chat.completions.create(
+        model=model,
+        messages=[
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt}
+        ],
+        temperature=temperature
+    )
+    return response.choices[0].message.content
+
+
+def call_llm_streaming(client, system_prompt, user_prompt,
+                       model="Qwen/Qwen3-14B", temperature=0):
+    # Start a streaming chat completion and yield each content delta.
+    stream = client.chat.completions.create(
+        model=model,
+        messages=[
+            {"role": "system", "content": system_prompt},
+            {"role": "user", "content": user_prompt}
+        ],
+        temperature=temperature,
+        stream=True
+    )
+    out = ""
+    for chunk in stream:
+        delta = chunk.choices[0].delta.content
+        if delta:
+            out += delta
+            yield delta
+    return out
+
+
+def format_thinking(thinking):
+    return "" if thinking else "/no_think"
+
+
+def extract_tagged_content(text, tag):
+    pattern = fr"<{tag}>(.*?)</{tag}>"
+    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
+    return match.group(1).strip() if match else None
+
+
+def generate_answer(client, sql, query, result, thinking=False):
+    prompt = (
+        f"The user asked this question: \n\"{query}\"\n"
+        f"The executed SQL was: \n\"{sql}\"\n"
+        f"This is the result: \n{result}\n"
+        f"Please give an answer to the user. "
+        f"Use a friendly tone and at the end suggest follow-up questions related to the user's original question {format_thinking(thinking)}"
+    )
+    system = "You are an expert in providing responses to questions based on provided content"
+    return call_llm_streaming(client, system, prompt)
+
+
+def evaluate_difficulty(client, query):
+    prompt = (
+        f"Output a grade between 0 and 10 on how difficult it is to generate an SQL query "
+        f"to answer this question:\n{query}\n/no_think"
+    )
+    system = (
+        "Your task is to evaluate the level of difficulty for generating an SQL query. "
+        "You will only output the difficulty level, which is between 0 and 10, in <score></score> tags"
+    )
+    content = call_llm(client, system, prompt)
+    return extract_tagged_content(content, "score")
+
+
+def generate_sql(client, query, tables, thinking=False):
+    schema_info = "## Database tables\n" + "\n".join(tables)
+    prompt = (
+        f"Generate an SQL query to answer this question: \"{query}\"\n"
+        f"Based on this database information:\n{schema_info} {format_thinking(thinking)}"
+    )
+    system = (
+        "You are an expert in generating SQL queries based on a given schema. "
+        "You will output the generated query in <sql></sql> tags. "
+        "Attention: you can only run one SQL query, so if you need multiple steps, you must use subqueries."
+    )
+    content = call_llm(client, system, prompt)
+    return extract_tagged_content(content, "sql")
+
+
+def correct_sql(client, question, query, tables, error, thinking=True):
+    schema_info = "## Database tables\n" + "\n".join(tables)
+    prompt = (
+        f"To answer this question: \"{question}\", I tried to run this SQL query:\n{query}\n"
+        f"But I got this error:\n{error}\n"
+        f"Please take care of the provided schema and give a correct SQL query to answer the question. "
+        f"Output the query in <sql></sql> tags.\n{schema_info} {format_thinking(thinking)}"
+    )
+    system = (
+        "You are an expert in generating SQL queries based on a given schema. "
+        "You will output the generated query in <sql></sql> tags."
+    )
+    content = call_llm(client, system, prompt)
+    return extract_tagged_content(content, "sql")
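All of the generation helpers share the same convention: the model is told to wrap its answer in a tag, and extract_tagged_content pulls it back out; a small illustration:

    from services.ai import extract_tagged_content

    text = "Some reasoning...\n<sql>SELECT 1</sql>"
    print(extract_tagged_content(text, "sql"))    # 'SELECT 1'
    print(extract_tagged_content(text, "score"))  # None when the tag is absent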
services/utils.py
ADDED
@@ -0,0 +1,45 @@
+import os
+from urllib.parse import urlparse
+
+
+def filter_tables(threshold, sorted_results):
+    return [doc["content"] for doc in sorted_results if doc["score"] > threshold]
+
+
+def get_db_scheme_from_uri(uri: str) -> str:
+    """
+    Given a SQLAlchemy-style connection URI, return its scheme name
+    (with any '+driver' suffix stripped).
+
+    Examples:
+        >>> get_db_scheme_from_uri("postgresql://user:pass@host/db")
+        'postgresql'
+
+        >>> get_db_scheme_from_uri("postgresql+psycopg2://user:pass@host/db")
+        'postgresql'
+
+        >>> get_db_scheme_from_uri("duckdb:///path/to/db.duckdb")
+        'duckdb'
+    """
+    parsed = urlparse(uri)
+    scheme = parsed.scheme
+    if not scheme:
+        raise ValueError(f"No scheme found in URI: {uri!r}")
+    # Strip any "+driver" suffix (e.g. "mysql+mysqldb")
+    return scheme.split("+", 1)[0]
+
+
+def extract_filename(path_or_url):
+    """
+    Extract the file name from a local path or a URL.
+
+    Args:
+        path_or_url (str): The file path or URL.
+
+    Returns:
+        str: The extracted file name.
+    """
+    parsed = urlparse(path_or_url)
+    if parsed.scheme in ('http', 'https', 'ftp'):
+        return os.path.basename(parsed.path)
+    else:
+        return os.path.basename(path_or_url)