# talk_to_data / app.py
# Commit ef15351 (verified): "Update app.py" — PD03 — 1.67 kB
# app.py
import logging

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# 1) Checkpoint: the publicly available TAPEX model fine-tuned on WikiSQL.
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"

# 2) The synthetic profitability table that every question is asked against.
df = pd.read_csv("synthetic_profit.csv")

# 3) Pipeline device index: first CUDA GPU (0) when one is visible, else CPU (-1).
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# 4) Tokenizer and seq2seq model, both loaded from the same checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# 5) Table-question-answering pipeline wired to the objects loaded above
#    (PyTorch backend, on the device chosen in step 3).
_qa_pipeline_kwargs = {
    "model": model,
    "tokenizer": tokenizer,
    "framework": "pt",
    "device": device,
}
table_qa = pipeline(task="table-question-answering", **_qa_pipeline_kwargs)
# 6) Define the QA function, casting all cells to strings to avoid float issues
def answer_profitability(question: str) -> str:
    """Answer a free-form *question* about the profitability table.

    The module-level ``df`` is cast to strings (to avoid float issues in the
    model's input path) and handed, together with the question, to the
    module-level ``table_qa`` pipeline.

    Args:
        question: Natural-language question from the UI. Empty, ``None`` or
            whitespace-only input is rejected without calling the model.

    Returns:
        The model's answer with surrounding whitespace removed,
        ``"No answer found."`` when the pipeline yields no (or an empty)
        answer, or ``"Error: <message>"`` if the pipeline raises.
    """
    if not (question and question.strip()):
        return "Please enter a question about the profitability data."

    # The pipeline accepts a DataFrame directly, so there is no need to
    # round-trip through a list of records (which it would rebuild into a
    # DataFrame anyway).
    table = df.astype(str)
    try:
        out = table_qa(table=table, query=question.strip())
    except Exception as e:
        # UI boundary: keep the app responsive, but record the traceback
        # instead of silently discarding it.
        logging.getLogger(__name__).exception(
            "Table QA failed for question %r", question
        )
        return f"Error: {e}"

    # A single query normally yields one dict; tolerate a list of results.
    if isinstance(out, list):
        out = out[0] if out else {}
    answer = out.get("answer") if isinstance(out, dict) else None
    # Strip surrounding whitespace from the decoded answer; treat an empty
    # answer the same as a missing one.
    answer = str(answer).strip() if answer is not None else ""
    return answer or "No answer found."
# 7) Define Gradio interface
# The description names the checkpoint via MODEL_ID rather than repeating the
# literal, so it cannot drift out of sync when the model is swapped.
iface = gr.Interface(
    fn=answer_profitability,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
    outputs="text",
    title="SAP Profitability Q&A (TAPEX-Base)",
    description=(
        "Free-form questions on the synthetic profitability dataset, "
        f"powered end-to-end by {MODEL_ID}."
    ),
)
# 8) Entry point: serve the UI only when this file is executed as a script.
if __name__ == "__main__":
    # "0.0.0.0" listens on every network interface; the app is served on port 7860.
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    iface.launch(**launch_options)