Spaces:
Sleeping
Sleeping
File size: 2,401 Bytes
ef15351 a5ece8b aa97025 8be1581 887b999 b8b6a66 aa97025 a5ece8b 68264bd b8b6a66 68264bd b8b6a66 8be1581 8cc354c 8be1581 aa97025 6a97111 aa97025 8be1581 60fddfe b8b6a66 8be1581 68264bd b8b6a66 68264bd b8b6a66 8cc354c b8b6a66 8cc354c b8b6a66 8cc354c 67fc297 b8b6a66 aa97025 67fc297 aa97025 887b999 b8b6a66 aa97025 6a97111 b8b6a66 ef15351 b8b6a66 ef15351 887b999 a5ece8b 68264bd aa97025 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 |
# app.py
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# 1) Load the synthetic profitability dataset (path is relative to the
#    working directory)
df = pd.read_csv("synthetic_profit.csv")

# 2) Coerce the measure columns to numeric dtypes; values that cannot be
#    parsed become NaN (errors="coerce") instead of raising
numeric_columns = ["Revenue", "Profit", "ProfitMargin"]
df = df.assign(
    **{
        name: pd.to_numeric(df[name], errors="coerce")
        for name in numeric_columns
    }
)

# 3) Describe the table as one "- <column>: <dtype>" line per column
schema_lines = [
    f"- {name}: {kind.name}" for name, kind in zip(df.columns, df.dtypes)
]
schema_text = "Table schema:\n" + "\n".join(schema_lines)
# 4) Few-shot examples: one SUM-by-group pattern and one filtered AVERAGE
#    pattern, joined into a single newline-separated text block
example_block = "\n".join(
    (
        "Example 1",
        "Q: Total profit by region?",
        "A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34",
        "Example 2",
        "Q: Average profit margin for Product B in Americas?",
        "A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18",
    )
)
# 5) Model & pipeline setup
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"

# Device index handed to the pipeline: 0 when CUDA is available, -1 otherwise
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Tokenizer and seq2seq weights are both loaded from the same checkpoint id
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

table_qa = pipeline(
    task="table-question-answering",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
)
# 6) QA function with schema-aware prompting
def answer_profitability(question: str) -> str:
    """Answer a free-text question about the profitability table.

    The question is embedded in a prompt made of the table schema and the
    few-shot examples; that prompt is passed as the query to the table-QA
    pipeline together with the full table.

    Args:
        question: Text typed by the user.

    Returns:
        The ``answer`` field of the pipeline result, ``"No answer found."``
        when the result has no such field, a short hint when the question
        is blank, or ``"Error: <message>"`` if anything raises while
        querying the pipeline.
    """
    # There is nothing to answer for a blank (or missing) question, so skip
    # the table conversion and the model call entirely.
    if not question or not question.strip():
        return "Please enter a question."
    question = question.strip()

    # Cast all cells to string so every value is serialized uniformly.
    # NOTE(review): the whole table is sent with every query; confirm it
    # fits the model's input limit for larger CSVs.
    table = df.astype(str).to_dict(orient="records")

    # Assemble the full prompt: schema, few-shot examples, then the question.
    prompt = f"""{schema_text}
{example_block}
Q: {question}
A:"""

    try:
        out = table_qa(table=table, query=prompt)
        return out.get("answer", "No answer found.")
    except Exception as e:
        # UI boundary: surface the failure as text in the output box rather
        # than letting the exception propagate to the web framework.
        return f"Error: {e}"
# 7) Gradio interface: a two-line text box in, plain text out
question_box = gr.Textbox(
    lines=2,
    placeholder="Ask a question about profitability…",
)
app_description = (
    "Every query is prefixed with your table’s schema and two few-shot examples, "
    "so the model learns to SUM, AVERAGE, FILTER, etc., without extra Python code."
)
iface = gr.Interface(
    fn=answer_profitability,
    inputs=question_box,
    outputs="text",
    title="SAP Profitability Q&A (Schema-Aware TAPEX)",
    description=app_description,
)

# 8) Launch the app only when run as a script; listens on all interfaces
#    (0.0.0.0) at port 7860
if __name__ == "__main__":
    launch_options = {"server_name": "0.0.0.0", "server_port": 7860}
    iface.launch(**launch_options)
|