# talk_to_data / app.py
# (Hugging Face Space file — uploaded by PD03, commit 996ed5a "Update app.py", 2.14 kB)
# app.py
import gradio as gr
import pandas as pd
import torch
import duckdb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# 1) Load data and register it in DuckDB.
# Expects synthetic_profit.csv next to this script; the DataFrame is exposed
# to SQL as the virtual table `sap` that generated queries run against.
df = pd.read_csv('synthetic_profit.csv')
conn = duckdb.connect(database=':memory:')  # in-memory DB: nothing persists across restarts
conn.register('sap', df)
# 2) Build a one-line schema description (column names only, no dtypes)
schema = ", ".join(df.columns) # e.g. "Region, Product, FiscalYear, ..."
# 3) Load TAPEX (WikiSQL) for SQL generation
# NOTE(review): TAPEX-WikiSQL is fine-tuned for table question answering, not
# free-form prompt-to-SQL translation — verify it actually emits executable SQL
# for this prompt format; answer_profitability surfaces the raw output to debug.
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
device = 0 if torch.cuda.is_available() else -1  # transformers convention: 0 = first GPU, -1 = CPU
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
sql_generator = pipeline(
"text2text-generation",
model=model,
tokenizer=tokenizer,
framework="pt",  # force the PyTorch backend
device=device,
# limit length so it doesn’t try to output the entire table!
max_length=128,
)
# 4) Your new QA function
def answer_profitability(question: str) -> str:
    """Translate a natural-language question into SQL and run it on DuckDB.

    Parameters
    ----------
    question : str
        Free-text question about the data registered as the ``sap`` table.

    Returns
    -------
    str
        Query results rendered as plain text, or a diagnostic message
        (including the generated SQL) when generation/execution fails,
        returns no rows, or produces a non-read-only statement.
    """
    question = (question or "").strip()
    if not question:
        return "Please enter a question."

    # 4a) Prompt the model to generate SQL
    prompt = (
        f"Translate to SQL for table `sap` with columns ({schema}):\n"
        f"Question: {question}\n"
        "SQL:"
    )
    generated = sql_generator(prompt)[0]['generated_text'].strip()
    # Some seq2seq models echo the prompt back; keep only the text after the
    # final "SQL:" marker when that happens.
    sql = generated.rsplit("SQL:", 1)[-1].strip() if "SQL:" in generated else generated

    # Safety guard: the SQL comes from an untrusted model, so refuse anything
    # that is not a read-only query (e.g. DROP/DELETE/UPDATE).
    if not sql.lower().lstrip("(").startswith(("select", "with")):
        return f"Refusing to run non-SELECT statement. Generated SQL was:\n{sql}"

    # 4b) Execute the generated SQL and return results
    try:
        result_df = conn.execute(sql).df()
        # pretty-print as text
        if result_df.empty:
            return f"No rows returned. Generated SQL was:\n{sql}"
        return result_df.to_string(index=False)
    except Exception as e:
        # if something goes wrong, show you the SQL so you can debug
        return f"Error executing SQL: {e}\n\nGenerated SQL:\n{sql}"
# 5) Gradio interface: one free-text question in, one text box of results out.
question_box = gr.Textbox(lines=2, placeholder="Ask about your SAP data…")
iface = gr.Interface(
    fn=answer_profitability,
    inputs=question_box,
    outputs="textbox",
    title="SAP Profitability Q&A (SQL-Generation)",
    description=(
        "Uses TAPEX to translate your natural-language question "
        "into a SQL query over the `sap` table, then runs it via DuckDB."
    ),
)
if __name__ == "__main__":
    # Bind to all interfaces on port 7860 (the standard HF Spaces port).
    iface.launch(server_name="0.0.0.0", server_port=7860)