# Jacqkues's picture
# Upload 11 files
# d25ee4b verified
import re
import requests
import os
def rank_tables(query, tables, timeout=30):
    """Rank candidate tables by relevance to *query* via the rerank service.

    Args:
        query: Natural-language user question.
        tables: List of table-description strings to score as documents.
        timeout: Seconds to wait for the rerank endpoint (default 30).
            Added so a stalled service cannot hang the caller forever.

    Returns:
        The service's result list sorted by descending "score".

    Raises:
        Exception: When the endpoint responds with a non-200 status.
    """
    task = "Given a user query find the useful tables in order to build an sql request"
    payload = {
        "task": task,
        "query": query,
        "documents": tables,
    }
    # NOTE(review): assumes RERANK_ENDPOINT is set; os.getenv returns None
    # otherwise and requests will raise — confirm deployment always sets it.
    response = requests.post(os.getenv('RERANK_ENDPOINT'), json=payload, timeout=timeout)
    if response.status_code == 200:
        results = response.json()
        return sorted(results, key=lambda x: x["score"], reverse=True)
    raise Exception(f"Request failed: {response.status_code} - {response.text}")
def call_llm(client, system_prompt, user_prompt, model="Qwen/Qwen3-14B", temperature=0):
    """Run a single non-streaming chat completion and return the reply text.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        system_prompt: Content for the "system" message.
        user_prompt: Content for the "user" message.
        model: Model identifier to query.
        temperature: Sampling temperature (0 for near-deterministic output).

    Returns:
        The assistant message content of the first choice.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=temperature,
    )
    return completion.choices[0].message.content
def call_llm_streaming(client, system_prompt, user_prompt,
                       model="Qwen/Qwen3-14B", temperature=0):
    """Stream a chat completion, yielding each text fragment as it arrives.

    Args:
        client: OpenAI-compatible client exposing ``chat.completions.create``.
        system_prompt: Content for the "system" message.
        user_prompt: Content for the "user" message.
        model: Model identifier to query.
        temperature: Sampling temperature.

    Yields:
        Non-empty text deltas from the streamed response.

    Returns:
        The full concatenated text (reachable via ``StopIteration.value``).
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    stream = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=temperature,
        stream=True,
    )
    pieces = []
    for event in stream:
        fragment = event.choices[0].delta.content
        # Skip empty/None deltas (e.g. role-only or final chunks).
        if fragment:
            pieces.append(fragment)
            yield fragment
    return "".join(pieces)
def format_thinking(thinking):
    """Return the Qwen prompt suffix controlling chain-of-thought.

    An empty string leaves thinking enabled; the literal "/no_think"
    directive disables the model's reasoning trace.
    """
    if thinking:
        return ""
    return "/no_think"
def extract_tagged_content(text, tag):
    """Extract the text between ``<tag>`` and ``</tag>`` in *text*.

    Matching is case-insensitive and spans newlines. The tag name is now
    regex-escaped so a tag containing regex metacharacters cannot corrupt
    the pattern (the original interpolated it raw).

    Args:
        text: String to search (typically raw LLM output).
        tag: Tag name without angle brackets, e.g. "sql" or "score".

    Returns:
        The stripped inner content of the first match, or None if absent.
    """
    safe_tag = re.escape(tag)
    pattern = fr"<{safe_tag}>(.*?)</{safe_tag}>"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None
def generate_answer(client, sql, query, result, thinking=False):
    """Stream a natural-language answer explaining an SQL result to the user.

    Args:
        client: OpenAI-compatible client forwarded to call_llm_streaming.
        sql: The SQL statement that was executed.
        query: The user's original question.
        result: The rows/value produced by executing *sql*.
        thinking: When False, appends the "/no_think" directive.

    Returns:
        A generator of streamed text fragments (see call_llm_streaming).
    """
    # Prompt wording fixed: "amical thone" -> "friendly tone", and the
    # missing sentence break after "user" added.
    prompt = (
        f"The user asked this question: \n\"{query}\"\n"
        f"The executed SQL was: \n\"{sql}\"\n"
        f"This is the result: \n{result}\n"
        f"Please give an answer to the user. "
        f"Use a friendly tone and at the end suggest followup questions related to the first user question {format_thinking(thinking)}"
    )
    system = "You are an expert in providing response to questions based on provided content"
    return call_llm_streaming(client, system, prompt)
def evaluate_difficulty(client, query):
    """Ask the LLM to grade (0-10) how hard *query* is to answer with SQL.

    Args:
        client: OpenAI-compatible client forwarded to call_llm.
        query: Natural-language user question.

    Returns:
        The grade as a string extracted from <score></score> tags, or
        None when the model emitted no such tags.
    """
    prompt = (
        f"Output a grade between 0 and 10 on how difficult it is to generate an SQL query "
        f"to answer this question:\n{query}\n/no_think"
    )
    # System-prompt grammar fixed: "You task" -> "Your task".
    system = (
        "Your task is to evaluate the level of difficulty for generating an sql query. "
        "You will only output the difficulty level which is between 0 and 10, output in <score></score> tags"
    )
    content = call_llm(client, system, prompt)
    return extract_tagged_content(content, "score")
def generate_sql(client, query, tables, thinking=False):
    """Ask the LLM for a single SQL query answering *query* over *tables*.

    Args:
        client: OpenAI-compatible client forwarded to call_llm.
        query: Natural-language user question.
        tables: Iterable of table-description strings forming the schema.
        thinking: When False, appends the "/no_think" directive.

    Returns:
        The SQL extracted from <sql></sql> tags, or None when absent.
    """
    table_listing = "\n".join(tables)
    schema_info = f"## Database tables\n{table_listing}"
    question_line = f"Generate an SQL query to answer this question: \"{query}\""
    prompt = f"{question_line}\nBased on this database information:\n{schema_info} {format_thinking(thinking)}"
    system = (
        "You are an expert in generating SQL queries based on a given schema. "
        "You will output the generated query in <sql></sql> tags. "
        "Attention: you can only run one SQL query, so if you need multiple steps, you must use subqueries."
    )
    reply = call_llm(client, system, prompt)
    return extract_tagged_content(reply, "sql")
def correct_sql(client, question, query, tables, error, thinking=True):
    """Ask the LLM to repair a failed SQL query given its execution error.

    Args:
        client: OpenAI-compatible client forwarded to call_llm.
        question: The user's original question.
        query: The SQL that was attempted and failed.
        tables: Iterable of table-description strings forming the schema.
        error: Error message returned when *query* was executed.
        thinking: When False, appends the "/no_think" directive
            (defaults to True here — correction benefits from reasoning).

    Returns:
        The corrected SQL extracted from <sql></sql> tags, or None.
    """
    table_listing = "\n".join(tables)
    schema_info = f"## Database tables\n{table_listing}"
    prompt = (
        f"To answer this question: \"{question}\", I tried to run this SQL query:\n{query}\n"
        f"But I got this error:\n{error}\n"
        f"Please take care of the provided schema and give a correct SQL to answer the question. "
        f"Output the query in <sql></sql> tags.\n{schema_info} {format_thinking(thinking)}"
    )
    system = (
        "You are an expert in generating SQL queries based on a given schema. "
        "You will output the generated query in <sql></sql> tags."
    )
    reply = call_llm(client, system, prompt)
    return extract_tagged_content(reply, "sql")