# talk_to_data/app.py
import gradio as gr
import pandas as pd
import torch
import duckdb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# Load data into DuckDB
df = pd.read_csv('synthetic_profit.csv')
con = duckdb.connect(':memory:')
con.register('sap', df)
# One-line schema for prompts
schema = ", ".join(df.columns)
# Load TAPEX for SQL generation
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
device = 0 if torch.cuda.is_available() else -1
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
sql_gen = pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
framework="pt",
device=device,
max_length=128,
)
def answer_profitability(question: str) -> str:
    # 1) Generate SQL
    prompt = (
        f"-- Translate to SQL for table `sap` ({schema})\n"
        f"Question: {question}\n"
        "SQL:"
    )
    sql = sql_gen(prompt)[0]['generated_text'].strip()

    # 2) Try to execute it
    try:
        df_out = con.execute(sql).df()
    except Exception as e:
        # Use a normal f-string with explicit \n for newlines
        return (
            f"❌ **SQL Error**\n"
            f"```\n{e}\n```\n\n"
            f"**Generated SQL**\n"
            f"```sql\n{sql}\n```"
        )

    # 3) Format successful result
    if df_out.empty:
        return (
            "No rows returned.\n\n"
            f"**Generated SQL**\n```sql\n{sql}\n```"
        )
    if df_out.shape == (1, 1):
        return str(df_out.iat[0, 0])
    return df_out.to_markdown(index=False)
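
# The file imports gradio but ends before any UI definition, so the wiring below
# is a hedged sketch of how answer_profitability could be exposed in the Space.
# The component choices (question textbox in, Markdown out) and the title are
# assumptions, not taken from the original app.
demo = gr.Interface(
    fn=answer_profitability,
    inputs=gr.Textbox(label="Question", placeholder="Ask about the profitability data"),
    outputs=gr.Markdown(label="Answer"),
    title="Talk to Data",
)

if __name__ == "__main__":
    demo.launch()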