PD03 committed on
Commit
25e4074
·
verified ·
1 Parent(s): 93045b6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -71
app.py CHANGED
@@ -1,89 +1,55 @@
1
  import gradio as gr
2
  import pandas as pd
3
- import tensorflow as tf
4
-
5
- from tapas.protos import interaction_pb2
6
- from tapas.utils import number_annotation_utils, tf_example_utils, prediction_utils
7
- from tapas.scripts.run_task_main import get_classifier_model, get_task_config
8
 
9
  # 1) Load & stringify your CSV
10
  df = pd.read_csv("synthetic_profit.csv")
11
- df = df.astype(str)
12
-
13
- # 2) Build the “list of lists” table (header + rows)
14
- table = [list(df.columns)]
15
- table.extend(df.values.tolist())
16
-
17
- # 3) Prepare the TAPAS converter with aggregation candidates
18
- config = tf_example_utils.ClassifierConversionConfig(
19
- vocab_file="tapas_sqa_base/vocab.txt",
20
- max_seq_length=512,
21
- max_column_id=512,
22
- max_row_id=512,
23
- strip_column_names=False,
24
- add_aggregation_candidates=True,
25
  )
26
- converter = tf_example_utils.ToClassifierTensorflowExample(config)
27
 
28
- # 4) Load pretrained TAPAS checkpoint
29
- task_config = get_task_config(
30
- task="sqa",
31
- init_checkpoint="tapas_sqa_base/model.ckpt-0",
32
- vocab_file=config.vocab_file,
33
- bsz=1,
34
- max_seq_length=config.max_seq_length,
35
- )
36
- model, tokenizer = get_classifier_model(task_config)
37
 
38
- # 5) Build a TF example from (table, query)
39
- def make_tf_example(table, query):
40
- interaction = interaction_pb2.Interaction()
41
- # question
42
- q = interaction.questions.add()
43
- q.original_text = query
44
- # columns
45
- for col in table[0]:
46
- interaction.table.columns.add().text = col
47
- # rows
48
- for row_vals in table[1:]:
49
- row = interaction.table.rows.add()
50
- for cell in row_vals:
51
- row.cells.add().text = cell
52
- # numeric annotation for SUM/AVG
53
- number_annotation_utils.add_numeric_values(interaction)
54
- # convert to serialized Example
55
- return converter.convert(interaction)
56
 
57
- # 6) Run TAPAS & parse coordinates back to cell values
58
- def predict_answer(query):
59
- example = make_tf_example(table, query)
60
- input_fn = tf_example_utils.input_fn_builder(
61
- [example],
62
- is_training=False,
63
- drop_remainder=False,
64
- batch_size=1,
65
- seq_length=config.max_seq_length,
66
- )
67
- preds = model.predict(input_fn)
68
- coords = prediction_utils.parse_coordinates(preds[0]["answer_coordinates"])
69
- answers = [ table[r+1][c] for (r, c) in coords ] # r+1 because row 0 is header
70
- return ", ".join(answers) if answers else "No answer found."
71
 
72
- # 7) Gradio interface
73
- def answer_fn(question: str) -> str:
74
  try:
75
- return predict_answer(question)
 
76
  except Exception as e:
77
- return f"❌ Error: {e}"
78
 
 
79
  iface = gr.Interface(
80
- fn=answer_fn,
81
- inputs=gr.Textbox(lines=2, label="Your question"),
82
- outputs=gr.Textbox(label="Answer"),
83
- title="SAP Profitability Q&A (TAPAS Low-Level)",
84
  description=(
85
- "TAPAS with aggregation candidates & numeric annotations—"
86
- "robust sums/averages on your SAP data."
87
  ),
88
  allow_flagging="never",
89
  )
 
1
  import gradio as gr
2
  import pandas as pd
3
+ from transformers import pipeline
 
 
 
 
4
 
5
  # 1) Load & stringify your CSV
6
  df = pd.read_csv("synthetic_profit.csv")
7
+ table = df.astype(str).to_dict(orient="records")
8
+
9
+ # 2) Instantiate the TAPAS pipeline from Transformers
10
+ qa = pipeline(
11
+ "table-question-answering",
12
+ model="google/tapas-base-finetuned-wtq",
13
+ tokenizer="google/tapas-base-finetuned-wtq",
14
+ device=-1, # CPU; change to 0 if you have a GPU
 
 
 
 
 
 
15
  )
 
16
 
17
+ # 3) Few-shot examples teach “filter + sum” vs. “filter + mean”
18
+ EXAMPLES = """
19
+ Example 1:
20
+ Q: What is the total revenue for Product A in EMEA in Q1 2024?
21
+ A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum Revenue → 3075162.49
22
+
23
+ Example 2:
24
+ Q: What is the total cost for Product A in EMEA in Q1 2024?
25
+ A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum Cost → 2894321.75
26
 
27
+ Example 3:
28
+ Q: What is the total margin for Product A in EMEA in Q1 2024?
29
+ A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum ProfitMargin → 0.18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
+ Example 4:
32
+ Q: What is the average profit margin for Product A in EMEA in Q1 2024?
33
+ A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then mean ProfitMargin → 0.18
34
+ """
 
 
 
 
 
 
 
 
 
 
35
 
36
+ def answer_question(question: str) -> str:
37
+ prompt = EXAMPLES + f"\nQ: {question}\nA:"
38
  try:
39
+ result = qa(table=table, query=prompt)
40
+ return result.get("answer", "No answer found.")
41
  except Exception as e:
42
+ return f"❌ Pipeline error:\n{e}"
43
 
44
+ # 4) Gradio UI
45
  iface = gr.Interface(
46
+ fn=answer_question,
47
+ inputs=gr.Textbox(lines=2, placeholder="e.g. What is the total revenue for Product A in Q1 2024?"),
48
+ outputs=gr.Textbox(lines=3),
49
+ title="SAP Profitability Q&A",
50
  description=(
51
+ "Ask simple sum/mean questions on the synthetic SAP data. \n"
52
+ "Powered by google/tapas-base-finetuned-wtq with four few-shot examples."
53
  ),
54
  allow_flagging="never",
55
  )