PD03 commited on
Commit
8cc354c
·
verified ·
1 Parent(s): ef15351

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +23 -12
app.py CHANGED
@@ -5,20 +5,18 @@ import pandas as pd
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
- # 1) Load your synthetic profitability dataset
9
  df = pd.read_csv('synthetic_profit.csv')
 
 
10
 
11
- # 2) Choose the publicly available TAPEX WikiSQL model
12
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
 
13
 
14
- # 3) Set device: GPU if available, else CPU
15
- device = 0 if torch.cuda.is_available() else -1
16
-
17
- # 4) Load tokenizer and model
18
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
19
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
20
 
21
- # 5) Build the table-question-answering pipeline
22
  table_qa = pipeline(
23
  "table-question-answering",
24
  model=model,
@@ -27,18 +25,32 @@ table_qa = pipeline(
27
  device=device,
28
  )
29
 
30
- # 6) Define the QA function, casting all cells to strings to avoid float issues
31
  def answer_profitability(question: str) -> str:
32
- # Cast entire DataFrame to string
 
 
 
 
 
 
 
33
  df_str = df.astype(str)
34
  table = df_str.to_dict(orient="records")
 
35
  try:
36
  out = table_qa(table=table, query=question)
37
  return out.get("answer", "No answer found.")
 
 
 
 
 
 
38
  except Exception as e:
39
  return f"Error: {e}"
40
 
41
- # 7) Define Gradio interface
42
  iface = gr.Interface(
43
  fn=answer_profitability,
44
  inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
@@ -46,10 +58,9 @@ iface = gr.Interface(
46
  title="SAP Profitability Q&A (TAPEX-Base)",
47
  description=(
48
  "Free-form questions on the synthetic profitability dataset, "
49
- "powered end-to-end by microsoft/tapex-base-finetuned-wikisql."
50
  )
51
  )
52
 
53
- # 8) Launch the app
54
  if __name__ == "__main__":
55
  iface.launch(server_name="0.0.0.0", server_port=7860)
 
5
  import torch
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
+ # 1) Load and preprocess your synthetic profitability dataset
9
  df = pd.read_csv('synthetic_profit.csv')
10
+ # Ensure numeric Profit for aggregation
11
+ df['Profit'] = pd.to_numeric(df['Profit'], errors='coerce')
12
 
13
+ # 2) Model setup
14
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
15
+ device = 0 if torch.cuda.is_available() else -1
16
 
 
 
 
 
17
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
18
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
19
 
 
20
  table_qa = pipeline(
21
  "table-question-answering",
22
  model=model,
 
25
  device=device,
26
  )
27
 
28
+ # 3) QA function with manual fallback for region‐based aggregations
29
  def answer_profitability(question: str) -> str:
30
+ q_lower = question.lower()
31
+
32
+ # Fallback: if user asks for total profit by region, do it in pandas
33
+ if "total profit" in q_lower and "region" in q_lower:
34
+ agg = df.groupby('Region', as_index=False)['Profit'].sum()
35
+ return "\n".join(f"{row.Region}: {row.Profit}" for row in agg.itertuples())
36
+
37
+ # Otherwise, cast all cells to string and try TAPEX
38
  df_str = df.astype(str)
39
  table = df_str.to_dict(orient="records")
40
+
41
  try:
42
  out = table_qa(table=table, query=question)
43
  return out.get("answer", "No answer found.")
44
+ except IndexError as e:
45
+ # Catch the 'index out of range' and fallback to pandas for any region/group queries
46
+ if "index out of range" in str(e):
47
+ agg = df.groupby('Region', as_index=False)['Profit'].sum()
48
+ return "\n".join(f"{row.Region}: {row.Profit}" for row in agg.itertuples())
49
+ return f"Error: {e}"
50
  except Exception as e:
51
  return f"Error: {e}"
52
 
53
+ # 4) Gradio interface
54
  iface = gr.Interface(
55
  fn=answer_profitability,
56
  inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
 
58
  title="SAP Profitability Q&A (TAPEX-Base)",
59
  description=(
60
  "Free-form questions on the synthetic profitability dataset, "
61
+ "powered by microsoft/tapex-base-finetuned-wikisql with pandas fallbacks."
62
  )
63
  )
64
 
 
65
  if __name__ == "__main__":
66
  iface.launch(server_name="0.0.0.0", server_port=7860)