File size: 2,401 Bytes
ef15351
 
a5ece8b
aa97025
8be1581
 
887b999
b8b6a66
aa97025
a5ece8b
68264bd
b8b6a66
 
 
 
68264bd
 
b8b6a66
 
 
 
 
 
 
 
 
 
 
 
 
 
8be1581
8cc354c
8be1581
aa97025
 
6a97111
aa97025
 
 
 
 
8be1581
60fddfe
 
b8b6a66
8be1581
68264bd
b8b6a66
 
68264bd
b8b6a66
8cc354c
b8b6a66
8cc354c
b8b6a66
 
8cc354c
67fc297
b8b6a66
aa97025
67fc297
aa97025
887b999
b8b6a66
aa97025
 
 
6a97111
b8b6a66
ef15351
b8b6a66
 
ef15351
887b999
a5ece8b
68264bd
aa97025
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
# app.py

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# 1) Load the synthetic profitability dataset (path is relative to the
#    working directory the app is started from).
df = pd.read_csv("synthetic_profit.csv")

# 2) Force the metric columns to a numeric dtype so they aggregate as
#    numbers; unparseable cells become NaN instead of raising.
df = df.assign(
    **{
        metric: pd.to_numeric(df[metric], errors="coerce")
        for metric in ("Revenue", "Profit", "ProfitMargin")
    }
)

# 3) Describe every column and its pandas dtype; this text is placed at the
#    top of each prompt sent to the model.
schema_lines = [
    f"- {column_name}: {column_dtype.name}"
    for column_name, column_dtype in df.dtypes.items()
]
schema_text = "Table schema:\n" + "\n".join(schema_lines)

# 4) Few-shot examples showing a SUM (Example 1) and an AVERAGE (Example 2).
# NOTE(review): Example 2 filters on Region=Americas, but Example 1 lists only
# EMEA, APAC, Latin America and North America — confirm "Americas" (and the
# 0.18 figure) actually exist in synthetic_profit.csv.
example_block = """
Example 1
Q: Total profit by region?
A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34

Example 2
Q: Average profit margin for Product B in Americas?
A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18
""".strip()

# 5) Model & pipeline setup
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"

# transformers pipelines take an integer device: a CUDA ordinal
# (0 = first GPU) or -1 for CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# Load the tokenizer and seq2seq weights explicitly, then wrap them in the
# table-question-answering pipeline on the PyTorch backend.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
table_qa = pipeline(
    task="table-question-answering",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
)

# 6) QA function with schema-aware prompting
def answer_profitability(question: str) -> str:
    """Answer a free-text question about ``df`` with the TAPEX pipeline.

    The question is appended to a prompt built from the module-level
    ``schema_text`` and ``example_block``. The prompt is sent to
    ``table_qa`` together with a string-typed copy of ``df``.

    Args:
        question: Text from the Gradio textbox. Surrounding whitespace is
            ignored.

    Returns:
        One of the following:

        - The pipeline's ``"answer"`` value.
        - ``"No answer found."`` if the result has no ``"answer"`` key.
        - ``"Please enter a question."`` for a blank question.
        - ``"Error: <message>"`` if the pipeline call fails.
    """
    # Skip the (slow) model call when the textbox is empty.
    if not question or not question.strip():
        return "Please enter a question."

    # Cast every cell to text, as before. Passing a DataFrame directly avoids
    # building a list of record dicts that the pipeline would only turn back
    # into a DataFrame.
    table = df.astype(str)

    # Assemble the full prompt: schema, few-shot examples, then the question.
    prompt = f"""{schema_text}

{example_block}

Q: {question.strip()}
A:"""

    try:
        # For generative (TAPEX) models the pipeline does not truncate by
        # default. The long prompt plus the flattened table can exceed the
        # model's maximum input length (1024 tokens for tapex-base), which
        # makes every query fail. truncation=True cuts the encoded input to
        # that limit instead; rows past the cut-off are then not seen.
        out = table_qa(table=table, query=prompt, truncation=True)
        return out.get("answer", "No answer found.")
    except Exception as e:  # UI boundary: report the failure as text
        return f"Error: {e}"

# 7) Gradio interface: one multi-line question box in, plain-text answer out.
# NOTE(review): the model is a WikiSQL fine-tune, not an instruction-tuned
# LLM, so the few-shot examples may not actually teach it to aggregate —
# verify answers before trusting the claim in this description.
_APP_DESCRIPTION = (
    "Every query is prefixed with your table’s schema and two few-shot examples, "
    "so the model learns to SUM, AVERAGE, FILTER, etc., without extra Python code."
)

iface = gr.Interface(
    fn=answer_profitability,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
    outputs="text",
    title="SAP Profitability Q&A (Schema-Aware TAPEX)",
    description=_APP_DESCRIPTION,
)

# 8) Start the web server only when run as a script; 0.0.0.0 listens on all
#    network interfaces so the app is reachable from outside a container.
if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)