|
import re |
|
import requests |
|
import os |
|
def rank_tables(query, tables, *, timeout=60):
    """Rank candidate table descriptions by relevance to ``query``.

    Sends ``query`` and ``tables`` to the rerank service whose URL is read
    from the ``RERANK_ENDPOINT`` environment variable, then returns the
    service's results sorted by their ``"score"`` field, highest first.

    Args:
        query: The user's natural-language question.
        tables: The table descriptions to rank (sent as ``documents``).
        timeout: Seconds to wait for the rerank service before giving up.
            Without a timeout the request could block indefinitely.

    Returns:
        The decoded JSON results sorted by descending ``"score"``, or an
        empty list when ``tables`` is empty (no request is made then).

    Raises:
        ValueError: If ``RERANK_ENDPOINT`` is not set.
        Exception: If the service answers with any status other than 200.
    """
    # Nothing to rank: skip the network round-trip entirely.
    if not tables:
        return []

    endpoint = os.getenv('RERANK_ENDPOINT')
    if not endpoint:
        # Without this check requests fails with an obscure MissingSchema
        # error about a ``None`` URL.
        raise ValueError("RERANK_ENDPOINT environment variable is not set")

    task = "Given a user query find the usefull tables in order to build an sql request"
    payload = {
        "task": task,
        "query": query,
        "documents": tables,
    }

    response = requests.post(endpoint, json=payload, timeout=timeout)
    if response.status_code != 200:
        raise Exception(f"Request failed: {response.status_code} - {response.text}")

    # NOTE(review): assumes the service returns a JSON list of objects that
    # each carry a "score" key — confirm against the rerank API.
    results = response.json()
    return sorted(results, key=lambda x: x["score"], reverse=True)
|
|
|
def call_llm(client, system_prompt, user_prompt, model="Qwen/Qwen3-14B", temperature=0):
    """Run one non-streaming chat completion and return the reply text.

    Args:
        client: An OpenAI-style client exposing ``chat.completions.create``.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        model: Model name forwarded to the API.
        temperature: Sampling temperature forwarded to the API.

    Returns:
        ``message.content`` of the first returned choice, passed through
        without any post-processing.
    """
    conversation = [
        {"role": "system", "content": system_prompt},
        {"role": "user", "content": user_prompt},
    ]
    completion = client.chat.completions.create(
        model=model,
        messages=conversation,
        temperature=temperature,
    )
    first_choice = completion.choices[0]
    return first_choice.message.content
|
|
|
def call_llm_streaming(client, system_prompt, user_prompt,
                       model="Qwen/Qwen3-14B", temperature=0):
    """Stream a chat completion, yielding text fragments as they arrive.

    Args:
        client: An OpenAI-style client exposing ``chat.completions.create``.
        system_prompt: Content of the system message.
        user_prompt: Content of the user message.
        model: Model name forwarded to the API.
        temperature: Sampling temperature forwarded to the API.

    Yields:
        Each non-empty ``delta.content`` string from the stream, in order.

    Returns:
        The full concatenated text as the generator's return value. It is
        reachable via ``text = yield from call_llm_streaming(...)`` or
        ``StopIteration.value``; a plain ``for`` loop discards it.
    """
    stream = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt},
        ],
        temperature=temperature,
        stream=True,
    )

    pieces = []
    for chunk in stream:
        # Some chunks (e.g. a final usage-only chunk when usage reporting is
        # enabled) carry an empty ``choices`` list; indexing [0] would raise.
        if not chunk.choices:
            continue
        delta = chunk.choices[0].delta.content
        if delta:
            pieces.append(delta)
            yield delta
    return "".join(pieces)
|
|
|
|
|
def format_thinking(thinking):
    """Return the prompt suffix that controls the model's reasoning mode.

    A falsy ``thinking`` yields ``"/no_think"`` (the Qwen3 directive to skip
    reasoning); a truthy one yields an empty string, adding nothing to the
    prompt.
    """
    if not thinking:
        return "/no_think"
    return ""
|
|
|
|
|
def extract_tagged_content(text, tag):
    """Return the stripped text between ``<tag>`` and ``</tag>`` in ``text``.

    Matching is case-insensitive and spans newlines. When the tag pair
    occurs more than once, the LAST occurrence wins. If the model's
    reasoning is inlined ahead of its answer (thinking mode), that reasoning
    may itself mention the tags (e.g. "output in <sql></sql> tags"), so the
    first match can be an empty or draft placeholder rather than the answer.

    Args:
        text: Model output to search; ``None`` or ``""`` yields ``None``.
        tag: Tag name, matched literally (regex metacharacters are escaped).

    Returns:
        The stripped content of the last ``<tag>...</tag>`` pair, or ``None``
        if no such pair exists.
    """
    # Chat APIs can return a null message content; treat it as "no match"
    # instead of letting re raise TypeError.
    if not text:
        return None
    escaped = re.escape(tag)
    pattern = fr"<{escaped}>(.*?)</{escaped}>"
    matches = re.findall(pattern, text, re.DOTALL | re.IGNORECASE)
    return matches[-1].strip() if matches else None
|
|
|
|
|
def generate_answer(client, sql, query, result, thinking=False):
    """Stream a user-facing answer built from the question, SQL and result.

    Args:
        client: An OpenAI-style client, forwarded to ``call_llm_streaming``.
        sql: The SQL statement that was executed.
        query: The user's original question.
        result: The query result, embedded in the prompt via ``str.format``
            semantics (f-string interpolation).
        thinking: When falsy, ``/no_think`` is appended to the prompt.

    Returns:
        The generator returned by ``call_llm_streaming``; iterate it to
        receive the answer text piece by piece.
    """
    suffix = format_thinking(thinking)
    prompt = "".join((
        f'The user asked this question: \n"{query}"\n',
        f'The executed SQL was: \n"{sql}"\n',
        f"This is the result: \n{result}\n",
        "Please give an answer to the user ",
        "Use an amical thone and at the end suggest followup questions "
        f"related to the first user question {suffix}",
    ))
    system = "You are an expert in providing response to questions based on provided content"
    return call_llm_streaming(client, system, prompt)
|
|
|
|
|
def evaluate_difficulty(client, query):
    """Ask the LLM to grade (0-10) how hard ``query`` is to turn into SQL.

    Thinking is always disabled for this call (``/no_think`` is appended).

    Args:
        client: An OpenAI-style client, forwarded to ``call_llm``.
        query: The user's natural-language question.

    Returns:
        The raw text found inside ``<score></score>`` in the reply (a string,
        not converted to a number), or ``None`` if the tags are missing.
    """
    system_prompt = (
        "You task is to evaluate the level of difficulty for generating an sql query. "
        "You will only output the difficulty level which is between 0 and 10, output in <score></score> tags"
    )
    user_prompt = (
        "Output a grade between 0 and 10 on how difficult it is to generate an SQL query "
        f"to answer this question:\n{query}\n/no_think"
    )
    reply = call_llm(client, system_prompt, user_prompt)
    return extract_tagged_content(reply, "score")
|
|
|
|
|
def generate_sql(client, query, tables, thinking=False):
    """Ask the LLM for a single SQL query answering ``query``.

    Args:
        client: An OpenAI-style client, forwarded to ``call_llm``.
        query: The user's natural-language question.
        tables: Table descriptions joined line by line into the prompt;
            each element must be a string (``str.join`` requires it).
        thinking: When falsy, ``/no_think`` is appended to the prompt.

    Returns:
        The SQL found inside ``<sql></sql>`` in the reply, or ``None``.
    """
    table_listing = "\n".join(tables)
    schema_info = f"## Database tables\n{table_listing}"
    directive = format_thinking(thinking)
    user_prompt = (
        f'Generate an SQL query to answer this question: "{query}"\n'
        f"Based on this database information:\n{schema_info} {directive}"
    )
    system_prompt = " ".join((
        "You are an expert in generating SQL queries based on a given schema.",
        "You will output the generated query in <sql></sql> tags.",
        "Attention: you can only run one SQL query, so if you need multiple steps, you must use subqueries.",
    ))
    reply = call_llm(client, system_prompt, user_prompt)
    return extract_tagged_content(reply, "sql")
|
|
|
|
|
def correct_sql(client, question, query, tables, error, thinking=True):
    """Ask the LLM to repair a SQL query that failed, and extract the fix.

    Unlike ``generate_sql``, thinking defaults to enabled here.

    Args:
        client: An OpenAI-style client, forwarded to ``call_llm``.
        question: The user's natural-language question.
        query: The SQL statement that failed.
        tables: Table descriptions joined line by line into the prompt;
            each element must be a string (``str.join`` requires it).
        error: The error reported when running ``query``.
        thinking: When falsy, ``/no_think`` is appended to the prompt.

    Returns:
        The corrected SQL found inside ``<sql></sql>`` in the reply, or
        ``None``.
    """
    table_listing = "\n".join(tables)
    schema_info = f"## Database tables\n{table_listing}"
    directive = format_thinking(thinking)
    user_prompt = "".join((
        f'To answer this question: "{question}", I tried to run this SQL query:\n{query}\n',
        f"But I got this error:\n{error}\n",
        "Please take care of the provided schema and give a correct SQL to answer the question. ",
        f"Output the query in <sql></sql> tags.\n{schema_info} {directive}",
    ))
    system_prompt = (
        "You are an expert in generating SQL queries based on a given schema. "
        "You will output the generated query in <sql></sql> tags."
    )
    answer = call_llm(client, system_prompt, user_prompt)
    return extract_tagged_content(answer, "sql")
|
|