import os
import re

import requests


def rank_tables(query, tables):
    """Rank candidate table schemas by relevance to the query via the rerank endpoint."""
    task = "Given a user query, find the useful tables in order to build an SQL request"
    payload = {
        "task": task,
        "query": query,
        "documents": tables,
    }

    response = requests.post(os.getenv("RERANK_ENDPOINT"), json=payload)

    if response.status_code == 200:
        # Assumed response shape: a list of {"document": ..., "score": ...} dicts;
        # most relevant tables come first after sorting.
        results = response.json()
        return sorted(results, key=lambda x: x["score"], reverse=True)
    raise RuntimeError(f"Rerank request failed: {response.status_code} - {response.text}")

def call_llm(client, system_prompt, user_prompt, model="Qwen/Qwen3-14B", temperature=0):
    """Run a single non-streaming chat completion and return the message text."""
    response = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=temperature
    )
    return response.choices[0].message.content

def call_llm_streaming(client, system_prompt, user_prompt,
                       model="Qwen/Qwen3-14B", temperature=0):
    """Yield response deltas as they arrive from a streaming chat completion."""
    stream = client.chat.completions.create(
        model=model,
        messages=[
            {"role": "system", "content": system_prompt},
            {"role": "user", "content": user_prompt}
        ],
        temperature=temperature,
        stream=True
    )
    out = ""
    for chunk in stream:
        delta = chunk.choices[0].delta.content
        if delta:
            out += delta
            yield delta
    # In a generator, this value only surfaces as StopIteration.value;
    # callers that simply iterate never see it directly.
    return out
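

# Usage sketch (hypothetical helper, not called elsewhere in this module):
# drain the streaming generator while echoing deltas, then return the joined
# text, since the generator's own return value is hidden in StopIteration.
def collect_stream(gen):
    pieces = []
    for delta in gen:
        print(delta, end="", flush=True)
        pieces.append(delta)
    return "".join(pieces)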


def format_thinking(thinking):
    """Return the Qwen3 soft switch that disables thinking mode when not wanted."""
    return "" if thinking else "/no_think"


def extract_tagged_content(text, tag):
    """Return the text inside the first <tag>...</tag> pair, or None if absent."""
    pattern = fr"<{tag}>(.*?)</{tag}>"
    match = re.search(pattern, text, re.DOTALL | re.IGNORECASE)
    return match.group(1).strip() if match else None


def generate_answer(client, sql, query, result, thinking=False):
    prompt = (
        f"The user asked this question: \n\"{query}\"\n"
        f"The executed SQL was: \n\"{sql}\"\n"
        f"This is the result: \n{result}\n"
        f"Please give an answer to the user. "
        f"Use a friendly tone and at the end suggest follow-up questions related to the first user question {format_thinking(thinking)}"
    )
    system = "You are an expert in providing responses to questions based on provided content"
    return call_llm_streaming(client, system, prompt)


def evaluate_difficulty(client, query):
    prompt = (
        f"Output a grade between 0 and 10 on how difficult it is to generate an SQL query "
        f"to answer this question:\n{query}\n/no_think"
    )
    system = (
        "Your task is to evaluate the level of difficulty for generating an SQL query. "
        "You will only output the difficulty level, which is between 0 and 10, inside <score></score> tags"
    )
    content = call_llm(client, system, prompt)
    return extract_tagged_content(content, "score")
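

# evaluate_difficulty returns the raw string found inside <score></score>
# (or None), so callers usually need to coerce it. A minimal sketch; the
# helper name and fallback value are assumptions for illustration.
def parse_score(raw, default=5):
    try:
        return int(raw)
    except (TypeError, ValueError):
        return default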


def generate_sql(client, query, tables, thinking=False):
    """Generate a single SQL query answering `query` from the given table schemas."""
    schema_info = "## Database tables\n" + "\n".join(tables)
    prompt = (
        f"Generate an SQL query to answer this question: \"{query}\"\n"
        f"Based on this database information:\n{schema_info} {format_thinking(thinking)}"
    )
    system = (
        "You are an expert in generating SQL queries based on a given schema. "
        "You will output the generated query in <sql></sql> tags. "
        "Attention: you can only run one SQL query, so if you need multiple steps, you must use subqueries."
    )
    content = call_llm(client, system, prompt)
    return extract_tagged_content(content, "sql")


def correct_sql(client, question, query, tables, error, thinking=True):
    """Ask the model to repair a failed SQL query, given the execution error."""
    schema_info = "## Database tables\n" + "\n".join(tables)
    prompt = (
        f"To answer this question: \"{question}\", I tried to run this SQL query:\n{query}\n"
        f"But I got this error:\n{error}\n"
        f"Please take care of the provided schema and give a correct SQL to answer the question. "
        f"Output the query in <sql></sql> tags.\n{schema_info} {format_thinking(thinking)}"
    )
    system = (
        "You are an expert in generating SQL queries based on a given schema. "
        "You will output the generated query in <sql></sql> tags."
    )
    content = call_llm(client, system, prompt)
    return extract_tagged_content(content, "sql")
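

# End-to-end usage sketch. Everything below is illustrative: the endpoint env
# vars, the demo question and schema, the sqlite run_sql helper, and the
# single correction retry are assumptions, not part of the pipeline above.
if __name__ == "__main__":
    import sqlite3

    from openai import OpenAI

    def run_sql(sql, db_path="demo.db"):
        # Hypothetical executor against a local SQLite file.
        with sqlite3.connect(db_path) as conn:
            return conn.execute(sql).fetchall()

    client = OpenAI(base_url=os.getenv("LLM_ENDPOINT"), api_key=os.getenv("LLM_API_KEY"))
    question = "How many orders were placed last month?"
    tables = ["orders(id INTEGER, customer_id INTEGER, created_at TEXT, total REAL)"]

    ranked = rank_tables(question, tables)
    top_tables = [r["document"] for r in ranked]  # assumed rerank response field
    sql = generate_sql(client, question, top_tables)
    try:
        result = run_sql(sql)
    except sqlite3.Error as err:
        # One correction round, mirroring correct_sql's intended use.
        sql = correct_sql(client, question, sql, top_tables, str(err))
        result = run_sql(sql)
    for delta in generate_answer(client, sql, question, result):
        print(delta, end="", flush=True)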