PD03 commited on
Commit
b8b6a66
·
verified ·
1 Parent(s): 8cc354c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -24
app.py CHANGED
@@ -5,12 +5,29 @@ import pandas as pd
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
- # 1) Load and preprocess your synthetic profitability dataset
9
  df = pd.read_csv('synthetic_profit.csv')
10
- # Ensure numeric Profit for aggregation
11
- df['Profit'] = pd.to_numeric(df['Profit'], errors='coerce')
12
 
13
- # 2) Model setup
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
14
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
15
  device = 0 if torch.cuda.is_available() else -1
16
 
@@ -25,40 +42,35 @@ table_qa = pipeline(
25
  device=device,
26
  )
27
 
28
- # 3) QA function with manual fallback for region‐based aggregations
29
  def answer_profitability(question: str) -> str:
30
- q_lower = question.lower()
 
 
 
 
31
 
32
- # Fallback: if user asks for total profit by region, do it in pandas
33
- if "total profit" in q_lower and "region" in q_lower:
34
- agg = df.groupby('Region', as_index=False)['Profit'].sum()
35
- return "\n".join(f"{row.Region}: {row.Profit}" for row in agg.itertuples())
36
 
37
- # Otherwise, cast all cells to string and try TAPEX
38
- df_str = df.astype(str)
39
- table = df_str.to_dict(orient="records")
40
 
 
41
  try:
42
- out = table_qa(table=table, query=question)
43
  return out.get("answer", "No answer found.")
44
- except IndexError as e:
45
- # Catch the 'index out of range' and fallback to pandas for any region/group queries
46
- if "index out of range" in str(e):
47
- agg = df.groupby('Region', as_index=False)['Profit'].sum()
48
- return "\n".join(f"{row.Region}: {row.Profit}" for row in agg.itertuples())
49
- return f"Error: {e}"
50
  except Exception as e:
51
  return f"Error: {e}"
52
 
53
- # 4) Gradio interface
54
  iface = gr.Interface(
55
  fn=answer_profitability,
56
  inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
57
  outputs="text",
58
- title="SAP Profitability Q&A (TAPEX-Base)",
59
  description=(
60
- "Free-form questions on the synthetic profitability dataset, "
61
- "powered by microsoft/tapex-base-finetuned-wikisql with pandas fallbacks."
62
  )
63
  )
64
 
 
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
+ # 1) Load your synthetic profitability dataset
9
  df = pd.read_csv('synthetic_profit.csv')
 
 
10
 
11
+ # 2) Ensure numeric columns for true aggregation (optional, but helps you verify sums)
12
+ for col in ["Revenue", "Profit", "ProfitMargin"]:
13
+ df[col] = pd.to_numeric(df[col], errors='coerce')
14
+
15
+ # 3) Build the schema description text
16
+ schema_lines = [f"- {col}: {dtype.name}" for col, dtype in df.dtypes.iteritems()]
17
+ schema_text = "Table schema:\n" + "\n".join(schema_lines)
18
+
19
+ # 4) Few-shot examples teaching SUM and AVERAGE patterns
20
+ example_block = """
21
+ Example 1
22
+ Q: Total profit by region?
23
+ A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34
24
+
25
+ Example 2
26
+ Q: Average profit margin for Product B in Americas?
27
+ A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18
28
+ """.strip()
29
+
30
+ # 5) Model & pipeline setup
31
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
32
  device = 0 if torch.cuda.is_available() else -1
33
 
 
42
  device=device,
43
  )
44
 
45
+ # 6) QA function with schema-aware prompting
46
  def answer_profitability(question: str) -> str:
47
+ # 6a) cast all cells to string for safety
48
+ table = df.astype(str).to_dict(orient="records")
49
+
50
+ # 6b) assemble the full prompt
51
+ prompt = f"""{schema_text}
52
 
53
+ {example_block}
 
 
 
54
 
55
+ Q: {question}
56
+ A:"""
 
57
 
58
+ # 6c) call TAPEX
59
  try:
60
+ out = table_qa(table=table, query=prompt)
61
  return out.get("answer", "No answer found.")
 
 
 
 
 
 
62
  except Exception as e:
63
  return f"Error: {e}"
64
 
65
+ # 7) Gradio interface
66
  iface = gr.Interface(
67
  fn=answer_profitability,
68
  inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
69
  outputs="text",
70
+ title="SAP Profitability Q&A (Schema-Aware TAPEX)",
71
  description=(
72
+ "Every query is prefixed with your table’s schema and two few-shot examples, "
73
+ "so the model learns to SUM, AVERAGE, FILTER, etc., without extra Python code."
74
  )
75
  )
76