File size: 2,376 Bytes
ef15351
 
a5ece8b
aa97025
8be1581
 
887b999
e9af0b4
aa97025
a5ece8b
e9af0b4
 
 
b8b6a66
e9af0b4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8be1581
8cc354c
8be1581
aa97025
 
6a97111
e9af0b4
 
aa97025
 
 
e9af0b4
60fddfe
 
e9af0b4
8be1581
e9af0b4
 
 
 
 
 
 
 
 
 
8cc354c
67fc297
e9af0b4
 
67fc297
e9af0b4
887b999
e9af0b4
aa97025
 
e9af0b4
 
 
ef15351
e9af0b4
 
ef15351
887b999
a5ece8b
aa97025
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
# app.py

import functools
import logging

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# 1) Load the synthetic profitability dataset.
df = pd.read_csv('synthetic_profit.csv')

# 2) Coerce the metric columns to numbers, column by column; any cell that
#    cannot be parsed becomes NaN instead of raising.
_numeric_cols = ["Revenue", "Profit", "ProfitMargin"]
df[_numeric_cols] = df[_numeric_cols].apply(pd.to_numeric, errors='coerce')

# 3) Describe every column as "- <name>: <dtype>" so the prompt carries the
#    table's schema ahead of the question.
schema_lines = [f"- {name}: {kind.name}" for name, kind in df.dtypes.items()]
schema_text = "Table schema:\n{}".format("\n".join(schema_lines))

# 4) Few-shot examples teaching SUM and AVERAGE.
# This text is inserted verbatim into every prompt built by the QA function.
# The numbers are hard-coded, so they will not follow changes to
# synthetic_profit.csv.
# NOTE(review): Example 1 lists the regions EMEA, APAC, Latin America and
# North America, but Example 2 filters on Region=Americas. Confirm that value
# exists in the data; otherwise the example demonstrates a filter that matches
# nothing. The 0.18 result cannot be derived from this file either, so verify it.
# .strip() removes the leading/trailing newlines of the triple-quoted literal.
few_shot = """
Example 1
Q: Total profit by region?
A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34

Example 2
Q: Average profit margin for Product B in Americas?
A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18
""".strip()

# 5) Load TAPEX (fine-tuned on WikiSQL) and wrap it in a table-QA pipeline.
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"

# First CUDA device when one is available, otherwise -1 (CPU) for the pipeline.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

_pipeline_kwargs = dict(
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
)
table_qa = pipeline("table-question-answering", **_pipeline_kwargs)

# 6) QA function using schema-aware prompting
logger = logging.getLogger(__name__)


@functools.lru_cache(maxsize=1)
def _string_table() -> pd.DataFrame:
    """Return ``df`` with every cell cast to ``str``, built once on first use.

    TAPEX consumes the table as text, so all values are stringified. The
    result is identical for every request, so it is computed lazily and cached.
    Assumes ``df`` is not modified after start-up; call
    ``_string_table.cache_clear()`` if it ever is.
    """
    return df.astype(str)


def _extract_answer(out: object) -> str:
    """Pull the answer text out of a table-QA pipeline result.

    Accepts a dict with an ``"answer"`` key, or a list whose first element is
    such a dict. A missing, ``None`` or blank answer (or any other shape)
    yields ``"No answer found."``.
    """
    if isinstance(out, list):
        out = out[0] if out else None
    answer = out.get("answer") if isinstance(out, dict) else None
    text = str(answer).strip() if answer is not None else ""
    return text or "No answer found."


def answer_profitability(question: str) -> str:
    """Answer a free-text profitability question with the TAPEX pipeline.

    The query sent to the model is the table schema, the few-shot examples
    and the user's question, ending with an open ``A:`` for the model.

    Args:
        question: Text from the UI textbox. ``None`` or whitespace-only input
            is rejected without invoking the model.

    Returns:
        The model's answer, ``"No answer found."`` when the pipeline yields
        no usable answer, or ``"Error: <message>"`` if the pipeline raises.
    """
    question = (question or "").strip()
    if not question:
        return "Please enter a question about profitability."

    # Assemble prompt with schema + examples + user question.
    prompt = f"{schema_text}\n\n{few_shot}\n\nQ: {question}\nA:"

    try:
        # The pipeline accepts a DataFrame directly, so there is no need to
        # round-trip through a list of records.
        # NOTE(review): for non-TAPAS models the pipeline appears to default to
        # no truncation. A table plus prompt longer than the model's input
        # limit may therefore raise here instead of being shortened. Confirm
        # this against the installed transformers version.
        out = table_qa(table=_string_table(), query=prompt)
    except Exception as exc:  # UI boundary: report the failure, keep serving
        logger.exception("Table QA failed for question %r", question)
        return f"Error: {exc}"

    return _extract_answer(out)

# 7) Gradio interface: one free-text question in, one plain-text answer out.
_question_box = gr.Textbox(lines=2, placeholder="Ask a question about profitability…")
_description = (
    "Every query is prefixed with your table’s schema and two few-shot examples, "
    "so the model learns to SUM, AVERAGE, FILTER, etc., without hard-coded fallbacks."
)

iface = gr.Interface(
    fn=answer_profitability,
    inputs=_question_box,
    outputs="text",
    title="SAP Profitability Q&A (Schema-Aware TAPEX)",
    description=_description,
)


def _main() -> None:
    """Serve the app on all network interfaces, port 7860."""
    iface.launch(server_name="0.0.0.0", server_port=7860)


if __name__ == "__main__":
    _main()