Spaces:
Sleeping
Sleeping
File size: 2,376 Bytes
ef15351 a5ece8b aa97025 8be1581 887b999 e9af0b4 aa97025 a5ece8b e9af0b4 b8b6a66 e9af0b4 8be1581 8cc354c 8be1581 aa97025 6a97111 e9af0b4 aa97025 e9af0b4 60fddfe e9af0b4 8be1581 e9af0b4 8cc354c 67fc297 e9af0b4 67fc297 e9af0b4 887b999 e9af0b4 aa97025 e9af0b4 ef15351 e9af0b4 ef15351 887b999 a5ece8b aa97025 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 |
# app.py
import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
# 1) Read the synthetic profitability data from disk
df = pd.read_csv('synthetic_profit.csv')

# 2) Force the three financial columns to a numeric dtype; cells that cannot
#    be parsed become NaN (errors='coerce') instead of raising
for col in ("Revenue", "Profit", "ProfitMargin"):
    df[col] = pd.to_numeric(df[col], errors='coerce')

# 3) Describe every column as "- <name>: <dtype>" for the prompt header
schema_lines = [
    f"- {name}: {kind.name}"
    for name, kind in zip(df.columns, df.dtypes)
]
schema_text = "Table schema:\n" + "\n".join(schema_lines)
# 4) Two worked examples (a grouped SUM and a filtered AVERAGE) that are
#    placed ahead of every user question in the prompt
few_shot = "\n".join([
    "Example 1",
    "Q: Total profit by region?",
    "A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34",
    "Example 2",
    "Q: Average profit margin for Product B in Americas?",
    "A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18",
])
# 5) Table-QA backbone: the TAPEX checkpoint fine-tuned on WikiSQL
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"

# Use the first CUDA device when one is available; otherwise pass -1
if torch.cuda.is_available():
    device = 0
else:
    device = -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# Wrap the already-loaded model and tokenizer in a table-QA pipeline
table_qa = pipeline(
    task="table-question-answering",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
)
# 6) QA function using schema-aware prompting
def answer_profitability(question: str) -> str:
    """Answer a free-text question about the profitability table.

    The query sent to the pipeline is the table schema, the few-shot
    examples, and the user's question, each on its own line(s), followed by
    a trailing "A:" cue.

    Args:
        question: The question text entered by the user.

    Returns:
        The "answer" field of the pipeline output, "No answer found." when
        that field is absent, or "Error: <message>" if the pipeline call or
        the answer lookup raises.
    """
    # Every cell is stringified before the table is handed to the pipeline
    stringified = df.astype(str)
    table = stringified.to_dict(orient="records")

    # Schema first, then the worked examples, then the actual question
    prompt = "\n".join((schema_text, few_shot, f"Q: {question}", "A:"))

    try:
        result = table_qa(table=table, query=prompt)
        return result.get("answer", "No answer found.")
    except Exception as e:
        # Report any failure as text in the UI instead of raising
        return f"Error: {e}"
# 7) Gradio interface: one two-line textbox in, plain text out
_question_box = gr.Textbox(
    lines=2,
    placeholder="Ask a question about profitability…",
)
_description = (
    "Every query is prefixed with your table’s schema and two few-shot examples, "
    "so the model learns to SUM, AVERAGE, FILTER, etc., without hard-coded fallbacks."
)

iface = gr.Interface(
    fn=answer_profitability,
    inputs=_question_box,
    outputs="text",
    title="SAP Profitability Q&A (Schema-Aware TAPEX)",
    description=_description,
)

# Start the web server only when the file is run as a script
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)
|