PD03 committed on
Commit
5edc373
·
verified ·
1 Parent(s): dfe31fe

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +74 -34
app.py CHANGED
@@ -1,60 +1,100 @@
1
  # app.py
2
 
 
3
  import gradio as gr
4
  import pandas as pd
5
  from transformers import pipeline
6
 
7
- # 1) Load & stringify your CSV
8
  df = pd.read_csv("synthetic_profit.csv")
9
- table = df.astype(str).to_dict(orient="records")
10
 
11
- # 2) TAPAS pipeline
12
- qa = pipeline(
13
  "table-question-answering",
14
  model="google/tapas-base-finetuned-wtq",
15
  tokenizer="google/tapas-base-finetuned-wtq",
16
  device=-1
17
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
- # 3) Instruction + few-shot examples
20
- PREFIX = """
21
- You are a table-QA assistant.
22
- - When the question asks for “total” or “sum” of a column:
23
- • Filter rows as specified.
24
- • Compute the sum of that column.
25
- • Return exactly one number (the sum).
26
- - When the question asks for “average” or “mean”:
27
- • Filter rows as specified.
28
- • Compute the mean.
29
- • Return exactly one number (the mean).
30
- """
31
-
32
- EXAMPLES = """
33
- Example 1:
34
- Q: What is the total revenue for Product A in EMEA in Q1 2024?
35
- A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum Revenue → 3075162.49
36
-
37
- Example 2:
38
- Q: What is the total revenue for Product A in Q1 2024?
39
- A: Filter Product=A & FiscalYear=2024 & FiscalQuarter=Q1, then sum Revenue → 12032469.96
40
- """
41
-
42
- def answer_question(question: str) -> str:
43
- prompt = PREFIX + EXAMPLES + f"\nQ: {question}\nA:"
44
  try:
45
- out = qa(table=table, query=prompt)
46
- return out.get("answer", "No answer found.")
47
  except Exception as e:
48
  return f"❌ Pipeline error:\n{e}"
49
 
50
  # 4) Gradio UI
51
  iface = gr.Interface(
52
- fn=answer_question,
53
- inputs=gr.Textbox(lines=2, placeholder="e.g. What is the total revenue for Product A in EMEA in Q1 2024?"),
54
  outputs=gr.Textbox(lines=2),
55
  title="SAP Profitability Q&A",
56
  description=(
57
- "Table‐QA with TAPAS: instructions + examples force a single-number sum/mean output."
 
58
  ),
59
  allow_flagging="never",
60
  )
 
1
  # app.py
2
 
3
+ import re
4
  import gradio as gr
5
  import pandas as pd
6
  from transformers import pipeline
7
 
8
+ # 1) Load your synthetic SAP data
9
  df = pd.read_csv("synthetic_profit.csv")
 
10
 
11
+ # 2) Prepare TAPAS as a fallback (optional)
12
+ tapas = pipeline(
13
  "table-question-answering",
14
  model="google/tapas-base-finetuned-wtq",
15
  tokenizer="google/tapas-base-finetuned-wtq",
16
  device=-1
17
  )
18
+ table = df.astype(str).to_dict(orient="records")
19
+
20
+ # 3) Mapping words → pandas methods and columns
21
+ OPERATIONS = {
22
+ "total": "sum",
23
+ "sum": "sum",
24
+ "average": "mean",
25
+ "mean": "mean"
26
+ }
27
+ COLUMNS = {
28
+ "revenue": "Revenue",
29
+ "cost": "Cost",
30
+ "profit margin": "ProfitMargin",
31
+ "profit": "Profit",
32
+ "margin": "ProfitMargin"
33
+ }
34
+
35
+ def parse_and_compute(question: str) -> str | None:
36
+ q = question.lower()
37
+
38
+ # 1) What operation?
39
+ op = next((OPERATIONS[k] for k in OPERATIONS if k in q), None)
40
+ # 2) Which column?
41
+ col = next((COLUMNS[k] for k in COLUMNS if k in q), None)
42
+ # 3) Which product?
43
+ prod = next((p for p in df["Product"].unique() if p.lower() in q), None)
44
+ # 4) Which region? (optional)
45
+ region = next((r for r in df["Region"].unique() if r.lower() in q), None)
46
+ # 5) Which year?
47
+ m_y = re.search(r"\b(20\d{2})\b", q)
48
+ year = int(m_y.group(1)) if m_y else None
49
+ # 6) Which quarter?
50
+ qtr = next((fq for fq in df["FiscalQuarter"].unique() if fq.lower() in q), None)
51
+
52
+ # Must have at least: op, col, prod, year, qtr
53
+ if None in (op, col, prod, year, qtr):
54
+ return None
55
+
56
+ # Build the mask
57
+ mask = (
58
+ (df["Product"] == prod) &
59
+ (df["FiscalYear"] == year) &
60
+ (df["FiscalQuarter"] == qtr)
61
+ )
62
+ if region:
63
+ mask &= (df["Region"] == region)
64
+
65
+ # Compute
66
+ try:
67
+ series = df.loc[mask, col]
68
+ result = getattr(series, op)()
69
+ except Exception:
70
+ return None
71
+
72
+ # Friendly formatting
73
+ region_part = f" in {region}" if region else ""
74
+ return f"{op.capitalize()} {col} for {prod}{region_part}, {qtr} {year}: {result:.2f}"
75
+
76
def answer(question: str) -> str:
    """Answer a user question, preferring the pandas parser over TAPAS.

    Args:
        question: Free-text question from the UI textbox.

    Returns:
        The parser's formatted result when ``parse_and_compute`` understands
        the question; otherwise the TAPAS pipeline's ``"answer"`` field, a
        "No answer found." placeholder, or a "Pipeline error" message if the
        pipeline raises.  Blank input yields a short prompt instead of being
        sent to either backend.
    """
    # Blank input would only surface as a confusing pipeline error.
    if not question or not question.strip():
        return "Please enter a question."

    # 1) Try the generic parser + Pandas
    computed = parse_and_compute(question)
    if computed is not None:
        return computed

    # 2) Fallback to TAPAS for anything else
    try:
        res = tapas(table=table, query=question)
        # `or` also covers an empty answer string, not just a missing key.
        return res.get("answer") or "No answer found."
    except Exception as e:
        # UI boundary: show the failure to the user rather than crash.
        return f"❌ Pipeline error:\n{e}"
88
 
89
  # 4) Gradio UI
90
  iface = gr.Interface(
91
+ fn=answer,
92
+ inputs=gr.Textbox(lines=2, placeholder="e.g. What is the total revenue for Product A in Q1 2024?"),
93
  outputs=gr.Textbox(lines=2),
94
  title="SAP Profitability Q&A",
95
  description=(
96
+ "Generic sum/mean parsing via Pandas (region optional), "
97
+ "falling back to TAPAS only if the question doesn't match."
98
  ),
99
  allow_flagging="never",
100
  )