File size: 2,073 Bytes
ba2b9e4 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 |
from services.ai import rank_tables,generate_sql,generate_answer , correct_sql , evaluate_difficulty
from services.utils import filter_tables
from openai import OpenAI
from database import Database
from filesource import FileSource
import os
# Upper bound on how many tables are handed to the LLM; when a source has
# more tables than this, they are reranked and truncated first.
MAX_TABLE = 3

# Shared OpenAI-compatible client used by every LLM call in this module.
# The endpoint URL and API key are read from the LLM_ENDPOINT and LLM_KEY
# environment variables (None when unset).
client = OpenAI(
    api_key=os.environ.get("LLM_KEY"),
    base_url=os.environ.get("LLM_ENDPOINT"),
)
def _select_tables(database, prompt):
    """Return the tables to show the LLM, reranked and capped at MAX_TABLE."""
    tables = database.get_tables_array()
    if len(tables) > MAX_TABLE:
        print(f"using reranking because number of tables is greater than {MAX_TABLE}")
        ranked = rank_tables(prompt, tables)
        tables = filter_tables(0, ranked)[:MAX_TABLE]
    return tables


def _needs_thinking(prompt):
    """Return True when the LLM rates ``prompt`` with a difficulty above 7."""
    raw = evaluate_difficulty(client, prompt)
    try:
        dif = int(raw)
    except (TypeError, ValueError):
        # The rating is produced by an LLM and may not be a bare integer;
        # fall back to the cheaper non-thinking mode instead of crashing.
        print(f"could not parse difficulty {raw!r}, thinking mode stays disabled")
        return False
    if dif > 7:
        print("difficulty is > 7 so we enable thinking mode")
        return True
    return False


def _execute_with_correction(database, prompt, sql, tables, retry):
    """Run ``sql``, asking the LLM to repair it after each failed attempt.

    Returns:
        ``(success, result, sql)`` where ``sql`` is the last query executed
        and ``result`` is ``None`` when no attempt succeeded.
    """
    for attempt in range(1, retry + 1):
        try:
            print("try to launch sql request")
            result = database.query(sql)
        except Exception as e:  # any DB error is fed back to the LLM to fix
            print(f"Error : {e}")
            if attempt == retry:
                # No attempts left: a correction here would never be executed.
                break
            print("Try to self correct...")
            error = f"{type(e).__name__} - {str(e)}"
            # The last flag of correct_sql is True for the early corrections
            # and False for the last ones (same schedule as before).
            sql = correct_sql(client, prompt, sql, tables, error, attempt < retry - 2)
            print(sql)
        else:
            return True, result, sql
    return False, None, sql


def run_agent(database, prompt, give_answer=True, retry=5):
    """Answer a natural-language question by generating and running SQL.

    Steps:
      1. If the source has more than ``MAX_TABLE`` tables, rerank them against
         the prompt and keep only the top ``MAX_TABLE``.
      2. Ask the LLM to rate the prompt's difficulty; enable thinking mode
         when the rating is above 7.
      3. Generate SQL and execute it. On failure, feed the error back to the
         LLM for self-correction, for at most ``retry`` executions in total.

    Args:
        database: data source exposing ``get_tables_array()`` and
            ``query(sql)``; the result of ``query`` must support
            ``to_markdown()``.
        prompt: the user's question.
        give_answer: if True, return an LLM-written answer; otherwise return
            the generated SQL together with the raw result as markdown.
        retry: maximum number of query executions (default 5).

    Returns:
        A string (the answer, or the SQL plus result), or ``None`` when no
        generated query executed successfully.
    """
    tables = _select_tables(database, prompt)
    use_thinking = _needs_thinking(prompt)
    sql = generate_sql(client, prompt, tables, use_thinking)

    success, result, sql = _execute_with_correction(database, prompt, sql, tables, retry)
    if not success:
        print(f"Giving up: the SQL query still failed after {retry} attempt(s)")
        return None

    # Render once; the markdown is both printed and passed on.
    table_md = result.to_markdown()
    print(table_md)
    if give_answer:
        return generate_answer(client, sql, prompt, table_md, use_thinking)
    return f"Generated sql query : {sql}\n Query Result : \n {table_md}"
# db = Database("mysql://user:password@localhost:3306/Pokemon")
# db.connect()
# file = FileSource("./Wines.csv")
# file.connect()
# print(run_agent(file,"What is the quality of the wine with the least alcohol?"))
|