PD03 committed on
Commit
519d64c
·
verified ·
1 Parent(s): e9af0b4

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -43
app.py CHANGED
@@ -1,75 +1,68 @@
1
- # app.py
2
-
3
  import gradio as gr
4
  import pandas as pd
5
  import torch
 
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
- # 1) Load your synthetic profitability dataset
9
  df = pd.read_csv('synthetic_profit.csv')
 
 
10
 
11
- # 2) Ensure numeric types for Revenue, Profit, ProfitMargin
12
- for col in ["Revenue", "Profit", "ProfitMargin"]:
13
- df[col] = pd.to_numeric(df[col], errors='coerce')
14
-
15
- # 3) Build the schema description
16
- schema_lines = [f"- {col}: {dtype.name}" for col, dtype in df.dtypes.items()]
17
- schema_text = "Table schema:\n" + "\n".join(schema_lines)
18
-
19
- # 4) Few-shot examples teaching SUM and AVERAGE
20
- few_shot = """
21
- Example 1
22
- Q: Total profit by region?
23
- A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34
24
 
25
- Example 2
26
- Q: Average profit margin for Product B in Americas?
27
- A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18
28
- """.strip()
29
-
30
- # 5) Load TAPEX-WikiSQL for table-QA
31
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
32
  device = 0 if torch.cuda.is_available() else -1
33
 
34
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
35
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
36
 
37
- table_qa = pipeline(
38
- "table-question-answering",
39
  model=model,
40
  tokenizer=tokenizer,
41
  framework="pt",
42
- device=device
 
43
  )
44
 
45
- # 6) QA function using schema-aware prompting
46
  def answer_profitability(question: str) -> str:
47
- # Cast all values to strings so TAPEX can ingest them
48
- table = df.astype(str).to_dict(orient="records")
49
-
50
- # Assemble prompt with schema + examples + user question
51
- prompt = f"""{schema_text}
52
-
53
- {few_shot}
54
-
55
- Q: {question}
56
- A:"""
57
 
 
58
  try:
59
- out = table_qa(table=table, query=prompt)
60
- return out.get("answer", "No answer found.")
61
  except Exception as e:
62
- return f"Error: {e}"
 
 
 
 
 
 
 
 
 
 
63
 
64
- # 7) Gradio interface
65
  iface = gr.Interface(
66
  fn=answer_profitability,
67
  inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
68
  outputs="text",
69
- title="SAP Profitability Q&A (Schema-Aware TAPEX)",
70
  description=(
71
- "Every query is prefixed with your table’s schema and two few-shot examples, "
72
- "so the model learns to SUM, AVERAGE, FILTER, etc., without hard-coded fallbacks."
73
  )
74
  )
75
 
 
 
 
1
  import gradio as gr
2
  import pandas as pd
3
  import torch
4
+ import duckdb
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
 
7
+ # 1) Load data into DuckDB
8
  df = pd.read_csv('synthetic_profit.csv')
9
+ con = duckdb.connect(':memory:')
10
+ con.register('sap', df)
11
 
12
+ # 2) Build a one-line schema for prompting
13
+ schema = ", ".join(df.columns) # e.g. "Region,Product,FiscalYear,...."
 
 
 
 
 
 
 
 
 
 
 
14
 
15
+ # 3) Load TAPEX-WikiSQL as a text2text generator
 
 
 
 
 
16
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
17
  device = 0 if torch.cuda.is_available() else -1
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
20
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
21
 
22
+ sql_gen = pipeline(
23
+ "text2text-generation",
24
  model=model,
25
  tokenizer=tokenizer,
26
  framework="pt",
27
+ device=device,
28
+ max_length=128,
29
  )
30
 
31
+ # 4) Core QA fn: NL → SQL → execute → return result
32
  def answer_profitability(question: str) -> str:
33
+ # a) Prompt TAPEX to generate SQL
34
+ prompt = (
35
+ f"-- Translate to SQL over table `sap` with columns ({schema})\n"
36
+ f"Question: {question}\n"
37
+ "SQL:"
38
+ )
39
+ sql = sql_gen(prompt)[0]['generated_text'].strip()
 
 
 
40
 
41
+ # b) Execute the generated SQL
42
  try:
43
+ result_df = con.execute(sql).df()
 
44
  except Exception as e:
45
+ return f"❌ SQL Error: {e}\n\nGenerated SQL:\n{sql}"
46
+
47
+ # c) Format the output
48
+ if result_df.empty:
49
+ return f"No rows returned.\n\nGenerated SQL:\n{sql}"
50
+
51
+ # If it's a single cell result, just return that number
52
+ if result_df.shape == (1,1):
53
+ return str(result_df.iat[0,0])
54
+ # Otherwise pretty-print the DataFrame
55
+ return result_df.to_string(index=False)
56
 
57
+ # 5) Gradio UI
58
  iface = gr.Interface(
59
  fn=answer_profitability,
60
  inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
61
  outputs="text",
62
+ title="SAP Profitability Q&A (SQL-Generation)",
63
  description=(
64
+ "TAPEX converts your natural-language query into SQL,\n"
65
+ "then runs it via DuckDB—no hard-coded fallbacks."
66
  )
67
  )
68