# talk_to_data / app.py (author: PD03)
# Last commit: 68264bd "Update app.py" (verified), 2.4 kB
# app.py
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# 1) Load the synthetic profitability dataset (relative path, resolved
#    against the current working directory).
df = pd.read_csv("synthetic_profit.csv")
# 2) Force the measure columns to numeric dtypes so aggregations operate on
#    real numbers; values that cannot be parsed become NaN (errors="coerce").
df = df.assign(
    **{
        name: pd.to_numeric(df[name], errors="coerce")
        for name in ("Revenue", "Profit", "ProfitMargin")
    }
)
# 3) Build the "Table schema:" header used at the top of every prompt:
#    one "- <column>: <dtype>" line per column, in DataFrame column order.
#    Series.items() is used in place of the older .iteritems().
schema_lines = ["- {}: {}".format(name, dt.name) for name, dt in df.dtypes.items()]
schema_text = "Table schema:\n{}".format("\n".join(schema_lines))
# 4) Few-shot examples showing the two answer patterns the prompt demonstrates:
#    group-by + SUM (Example 1) and filter + AVERAGE (Example 2).
#    Example 2 filters on a region that Example 1 actually lists; the earlier
#    "Americas" value matched none of the four regions shown in Example 1, so
#    the two examples contradicted each other.
#    NOTE(review): the 0.18 margin in Example 2 is illustrative only — confirm
#    it against synthetic_profit.csv (or replace it) before relying on it.
example_block = """
Example 1
Q: Total profit by region?
A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34
Example 2
Q: Average profit margin for Product B in North America?
A: Filter Product=B & Region=North America, take mean of “ProfitMargin” → 0.18
""".strip()
# 5) Model & pipeline setup: the TAPEX checkpoint fine-tuned on WikiSQL,
#    wrapped in a table-question-answering pipeline. device=0 selects the
#    first CUDA GPU; device=-1 keeps everything on the CPU.
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
device = -1
if torch.cuda.is_available():
    device = 0
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
table_qa = pipeline(
    task="table-question-answering",
    framework="pt",
    model=model,
    tokenizer=tokenizer,
    device=device,
)
# 6) QA function with schema-aware prompting
def answer_profitability(question: str) -> str:
    """Answer a natural-language question about the profitability table.

    The prompt sent to ``table_qa`` as its query is ``schema_text``, then
    ``example_block``, then ``"Q: <question>"`` and a trailing ``"A:"``, all
    joined by newlines. The dataset (every cell cast to ``str``) is passed as
    the table.

    Args:
        question: Free-text question typed by the user.

    Returns:
        The model's answer with surrounding whitespace removed.
        "No answer found." is returned when the pipeline yields no answer
        text. A prompt to type something is returned when the question is
        blank. "Error: <message>" is returned if the pipeline call raises.
    """
    # Don't spend a model call on an empty textbox.
    if not question or not question.strip():
        return "Please enter a question about profitability."

    # Cast all cells to string for safety. The pipeline accepts a DataFrame
    # directly, so skip the list-of-dicts round-trip it would undo anyway.
    table = df.astype(str)
    prompt = "\n".join([schema_text, example_block, f"Q: {question.strip()}", "A:"])

    try:
        out = table_qa(table=table, query=prompt)
    except Exception as e:  # UI boundary: report the failure, keep the app alive
        return f"Error: {e}"

    # NOTE(review): a single query is expected to yield one dict. A list of
    # dicts is tolerated too, so a shape change cannot crash on .get().
    if isinstance(out, list):
        out = out[0] if out else {}
    answer = out.get("answer") if isinstance(out, dict) else None
    answer = str(answer or "").strip()
    return answer or "No answer found."
# 7) Gradio UI: one two-line textbox wired to answer_profitability, with the
#    returned string shown as plain text.
iface = gr.Interface(
    title="SAP Profitability Q&A (Schema-Aware TAPEX)",
    description="Every query is prefixed with your table’s schema and two "
    "few-shot examples, so the model learns to SUM, AVERAGE, FILTER, etc., "
    "without extra Python code.",
    fn=answer_profitability,
    inputs=gr.Textbox(
        placeholder="Ask a question about profitability…",
        lines=2,
    ),
    outputs="text",
)
# 8) Start the web server only when executed as a script, not on import:
#    listen on all interfaces (0.0.0.0), port 7860.
if __name__ == "__main__":
    iface.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )