PD03 committed
Commit 8be1581 · verified · 1 Parent(s): 1373b6a

Update app.py

Files changed (1)
  1. app.py +14 -14
app.py CHANGED
@@ -1,27 +1,30 @@
-# app.py
 import gradio as gr
-from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 import pandas as pd
+import torch
+from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 
-# Load your synthetic profitability dataset
+# 1) Load data
 df = pd.read_csv('synthetic_profit.csv')
 
-# Initialize the TAPEX small model fine-tuned on WikiSQL
-MODEL_ID = "microsoft/tapex-small-finetuned-wikisql"
+# 2) Use the publicly available TAPEX base WikiSQL model
+MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
+
+# 3) Use GPU (device 0) if available, otherwise CPU (-1)
+device = 0 if torch.cuda.is_available() else -1
 
 tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
 model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
 
-# Build a table-QA pipeline
 table_qa = pipeline(
     "table-question-answering",
     model=model,
     tokenizer=tokenizer,
     framework="pt",
-    device=-1  # set to 0 if you enable GPU in your Space
+    device=device,
 )
 
-def answer_profitability(question):
+# 4) QA function
+def answer_profitability(question: str) -> str:
     table = df.to_dict(orient="records")
     try:
         out = table_qa(table=table, query=question)
@@ -29,16 +32,13 @@ def answer_profitability(question):
     except Exception as e:
         return f"Error: {e}"
 
-# Gradio interface
+# 5) Gradio UI
 iface = gr.Interface(
     fn=answer_profitability,
     inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
     outputs="text",
-    title="SAP Profitability Q&A (TAPEX-Small)",
-    description="""
-    Ask free-form questions on the synthetic profitability dataset.
-    Powered end-to-end by microsoft/tapex-small-finetuned-wikisql.
-    """
+    title="SAP Profitability Q&A (TAPEX-Base)",
+    description="Free-form questions on synthetic profitability data using microsoft/tapex-base-finetuned-wikisql."
 )
 
 if __name__ == "__main__":
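
For reference, a minimal standalone sketch of how the updated pipeline could be exercised outside the Gradio app. It reuses the model ID, device selection, and CSV path from the commit above; the example question, the str cast on the table, and the column names it implies are illustrative assumptions, not part of the committed code.

# Standalone sketch (not part of the commit): query the table-QA pipeline directly.
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
table_qa = pipeline(
    "table-question-answering",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
)

# The table-question-answering pipeline accepts a pandas DataFrame (or a dict of
# columns); casting cells to str sidesteps tokenizer issues with numeric values.
df = pd.read_csv("synthetic_profit.csv").astype(str)

# Hypothetical question; the actual columns of synthetic_profit.csv are not shown in the diff.
result = table_qa(table=df, query="Which product had the highest profit?")
print(result)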