File size: 2,073 Bytes
ba2b9e4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from services.ai import rank_tables,generate_sql,generate_answer , correct_sql , evaluate_difficulty
from services.utils import filter_tables
from openai import OpenAI
from database import Database
from filesource import FileSource
import os
# Maximum number of tables passed to the SQL generator. When the data source
# exposes more than this, run_agent re-ranks the tables and keeps this many.
MAX_TABLE = 3

# Shared OpenAI-compatible client used for every LLM call in this module.
# The endpoint URL and API key come from the LLM_ENDPOINT and LLM_KEY
# environment variables (os.getenv returns None when a variable is unset).
_llm_endpoint = os.getenv("LLM_ENDPOINT")
_llm_key = os.getenv("LLM_KEY")
client = OpenAI(base_url=_llm_endpoint, api_key=_llm_key)

def run_agent(database, prompt, give_answer=True, retry=5):
    """Answer a natural-language question by generating and running SQL.

    Pipeline:
      1. If the source has more than ``MAX_TABLE`` tables, re-rank them with
         ``rank_tables`` and keep at most ``MAX_TABLE``.
      2. Ask the LLM for a difficulty score; above 7, enable thinking mode.
      3. Generate SQL, then execute it up to ``retry`` times, asking the LLM
         to correct the query after each failure.

    Args:
        database: Data source exposing ``get_tables_array()`` and
            ``query(sql)`` (e.g. a ``Database`` or ``FileSource`` instance).
            ``query`` must return an object with ``to_markdown()``.
        prompt: The user's natural-language question.
        give_answer: If True, return an LLM-written answer built from the
            query result. Otherwise return the SQL and the raw result table.
        retry: Maximum number of attempts to execute the SQL (default 5).

    Returns:
        The answer string (or the SQL plus markdown result when
        ``give_answer`` is False), or ``None`` if every attempt failed.
    """
    tables = database.get_tables_array()

    if len(tables) > MAX_TABLE:
        print(f"using reranking because number of tables is greater than {MAX_TABLE}")
        # Presumably orders tables by relevance to the prompt; we keep the top MAX_TABLE.
        ranked = rank_tables(prompt, tables)
        tables = filter_tables(0, ranked)[:MAX_TABLE]

    # The difficulty score is LLM output. A malformed value should not abort
    # the whole run, so fall back to the cheaper non-thinking mode.
    try:
        difficulty = int(evaluate_difficulty(client, prompt))
    except (TypeError, ValueError) as e:
        print(f"could not parse difficulty ({e}), thinking mode stays disabled")
        difficulty = 0
    use_thinking = difficulty > 7
    if use_thinking:
        print("difficulty is > 7 so we enable thinking mode")

    sql = generate_sql(client, prompt, tables, use_thinking)

    result = None
    success = False
    for attempt in range(1, retry + 1):
        try:
            print("try to launch sql request")
            result = database.query(sql)
        except Exception as e:
            # Any backend error is fed back to the LLM for self-correction.
            print(f"Error : {e}")
            if attempt == retry:
                # Final attempt: a corrected query would never be executed,
                # so don't spend an LLM call producing one.
                break
            print("Try to self correct...")
            error = f"{type(e).__name__} - {str(e)}"
            # The last positional flag of correct_sql is True for early
            # attempts and False for the last ones (same schedule as before).
            # Its meaning is defined in services.ai; presumably thinking
            # mode, so verify there.
            sql = correct_sql(client, prompt, sql, tables, error, attempt < retry - 2)
            print(sql)
        else:
            success = True
            break

    if not success:
        print(f"giving up after {retry} failed attempts")
        return None

    markdown = result.to_markdown()
    print(markdown)
    if give_answer:
        return generate_answer(client, sql, prompt, markdown, use_thinking)
    return f"Generated sql query : {sql}\n Query Result : \n {markdown}"

                 
# Example usage (uncomment to run):
# db = Database("mysql://user:password@localhost:3306/Pokemon")
# db.connect()
# file = FileSource("./Wines.csv")
# file.connect()
# print(run_agent(file, "What is the quality of the wine with the least alcohol?"))