PD03 commited on
Commit
dfe31fe
·
verified ·
1 Parent(s): e2b2add

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -24
app.py CHANGED
@@ -1,12 +1,14 @@
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  from transformers import pipeline
4
 
5
- # Load & stringify your CSV
6
  df = pd.read_csv("synthetic_profit.csv")
7
  table = df.astype(str).to_dict(orient="records")
8
 
9
- # Instantiate TAPAS pipeline
10
  qa = pipeline(
11
  "table-question-answering",
12
  model="google/tapas-base-finetuned-wtq",
@@ -14,42 +16,48 @@ qa = pipeline(
14
  device=-1
15
  )
16
 
17
- # Four + one few-shot examples
 
 
 
 
 
 
 
 
 
 
 
 
18
  EXAMPLES = """
19
  Example 1:
20
  Q: What is the total revenue for Product A in EMEA in Q1 2024?
21
  A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum Revenue → 3075162.49
22
 
23
  Example 2:
24
- Q: What is the total cost for Product A in EMEA in Q1 2024?
25
- A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum Cost → 2894321.75
26
-
27
- Example 3:
28
- Q: What is the total margin for Product A in EMEA in Q1 2024?
29
- A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum ProfitMargin → 0.18
30
-
31
- Example 4:
32
- Q: What is the average profit margin for Product A in EMEA in Q1 2024?
33
- A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then mean ProfitMargin → 0.18
34
-
35
- Example 5:
36
  Q: What is the total revenue for Product A in Q1 2024?
37
- A: Filter Product=A & FiscalYear=2024 & FiscalQuarter=Q1, then sum Revenue → YOUR_SUM_HERE
38
  """
39
 
40
  def answer_question(question: str) -> str:
41
- prompt = EXAMPLES + f"\nQ: {question}\nA:"
42
- out = qa(table=table, query=prompt)
43
- return out.get("answer", "No answer found.")
44
-
 
 
 
 
45
  iface = gr.Interface(
46
  fn=answer_question,
47
- inputs=gr.Textbox(lines=2, placeholder="Ask a question…"),
48
- outputs=gr.Textbox(lines=3),
49
  title="SAP Profitability Q&A",
50
- description="TAPAS few-shot sum/mean demo",
 
 
51
  allow_flagging="never",
52
  )
53
 
54
  if __name__ == "__main__":
55
- iface.launch()
 
1
+ # app.py
2
+
3
  import gradio as gr
4
  import pandas as pd
5
  from transformers import pipeline
6
 
7
+ # 1) Load & stringify your CSV
8
  df = pd.read_csv("synthetic_profit.csv")
9
  table = df.astype(str).to_dict(orient="records")
10
 
11
+ # 2) TAPAS pipeline
12
  qa = pipeline(
13
  "table-question-answering",
14
  model="google/tapas-base-finetuned-wtq",
 
16
  device=-1
17
  )
18
 
19
+ # 3) Instruction + few-shot examples
20
+ PREFIX = """
21
+ You are a table-QA assistant.
22
+ - When the question asks for “total” or “sum” of a column:
23
+ • Filter rows as specified.
24
+ • Compute the sum of that column.
25
+ • Return exactly one number (the sum).
26
+ - When the question asks for “average” or “mean”:
27
+ • Filter rows as specified.
28
+ • Compute the mean.
29
+ • Return exactly one number (the mean).
30
+ """
31
+
32
  EXAMPLES = """
33
  Example 1:
34
  Q: What is the total revenue for Product A in EMEA in Q1 2024?
35
  A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum Revenue → 3075162.49
36
 
37
  Example 2:
 
 
 
 
 
 
 
 
 
 
 
 
38
  Q: What is the total revenue for Product A in Q1 2024?
39
+ A: Filter Product=A & FiscalYear=2024 & FiscalQuarter=Q1, then sum Revenue → 12032469.96
40
  """
41
 
42
  def answer_question(question: str) -> str:
43
+ prompt = PREFIX + EXAMPLES + f"\nQ: {question}\nA:"
44
+ try:
45
+ out = qa(table=table, query=prompt)
46
+ return out.get("answer", "No answer found.")
47
+ except Exception as e:
48
+ return f"❌ Pipeline error:\n{e}"
49
+
50
+ # 4) Gradio UI
51
  iface = gr.Interface(
52
  fn=answer_question,
53
+ inputs=gr.Textbox(lines=2, placeholder="e.g. What is the total revenue for Product A in EMEA in Q1 2024?"),
54
+ outputs=gr.Textbox(lines=2),
55
  title="SAP Profitability Q&A",
56
+ description=(
57
+ "Table‐QA with TAPAS: instructions + examples force a single-number sum/mean output."
58
+ ),
59
  allow_flagging="never",
60
  )
61
 
62
  if __name__ == "__main__":
63
+ iface.launch(server_name="0.0.0.0", server_port=7860)