Spaces:
Sleeping
Sleeping
File size: 1,674 Bytes
ef15351 a5ece8b aa97025 8be1581 887b999 ef15351 aa97025 a5ece8b ef15351 8be1581 ef15351 8be1581 887b999 ef15351 aa97025 6a97111 ef15351 aa97025 8be1581 60fddfe ef15351 8be1581 ef15351 67fc297 aa97025 67fc297 aa97025 887b999 ef15351 aa97025 6a97111 8be1581 ef15351 887b999 a5ece8b ef15351 aa97025 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 |
# app.py
import logging
from pathlib import Path

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# 1) Load the synthetic profitability dataset.
#    Prefer the CSV that sits next to this script so the app also starts when
#    launched from a different working directory; fall back to the original
#    CWD-relative path if no such file exists there.
DATA_PATH = Path(__file__).resolve().with_name("synthetic_profit.csv")
df = pd.read_csv(DATA_PATH if DATA_PATH.is_file() else "synthetic_profit.csv")
# 2) Checkpoint: TAPEX base fine-tuned on WikiSQL, publicly available on the Hub.
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"

# 3) Pipeline device index: 0 selects the first CUDA GPU, -1 selects the CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# 4) Tokenizer and sequence-to-sequence weights for the checkpoint above.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# 5) Table question-answering pipeline built from the preloaded model and
#    tokenizer, using the PyTorch backend on the device chosen above.
table_qa = pipeline(
    task="table-question-answering",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
)
# 6) QA handler used by the Gradio interface.
logger = logging.getLogger(__name__)


def answer_profitability(question: str) -> str:
    """Answer a free-form question about the profitability table.

    Every cell of the module-level ``df`` is cast to ``str`` before the table
    is passed to ``table_qa``, so float cells reach the model as text.

    Args:
        question: The user's natural-language question.

    Returns:
        The model's answer; ``"No answer found."`` when the pipeline output
        has no ``"answer"`` entry; ``"Please enter a question."`` when
        *question* is empty or whitespace; or ``"Error: <message>"`` when the
        pipeline raises.
    """
    # An empty query makes the pipeline raise; answer with a clear prompt
    # instead of surfacing an internal error string.
    if not question or not question.strip():
        return "Please enter a question."

    # Fresh string-typed copy on every call: the pipeline gets its own
    # DataFrame (it accepts one directly), and ``df`` itself is never touched.
    table = df.astype(str)
    try:
        # NOTE(review): assumption to confirm against the installed
        # transformers version -- the table-QA pipeline only enables
        # truncation by default for TAPAS, so an over-long linearized table
        # would exceed TAPEX's input limit. With truncation on, content past
        # that limit is simply not seen by the model.
        out = table_qa(table=table, query=question, truncation=True)
    except Exception as e:  # UI boundary: report the failure, keep the app up
        logger.exception("Table QA failed for question %r", question)
        return f"Error: {e}"

    # A single query normally yields one dict; tolerate a one-element list.
    if isinstance(out, list):
        out = out[0] if out else {}
    return out.get("answer", "No answer found.")
# 7) Gradio UI: a two-line question box in, plain-text answer out.
#    The description interpolates MODEL_ID so the advertised checkpoint always
#    matches the one loaded for the pipeline.
iface = gr.Interface(
    fn=answer_profitability,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
    outputs="text",
    title="SAP Profitability Q&A (TAPEX-Base)",
    description=(
        "Free-form questions on the synthetic profitability dataset, "
        f"powered end-to-end by {MODEL_ID}."
    ),
)
# 8) Script entry point: serve the UI on all interfaces, port 7860, only when
#    this file is executed directly (importing it does not start a server).
if __name__ == "__main__":
    _launch_options = {"server_port": 7860, "server_name": "0.0.0.0"}
    iface.launch(**_launch_options)
|