PD03 committed on
Commit
e9af0b4
·
verified ·
1 Parent(s): 1b29e74

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +43 -34
app.py CHANGED
@@ -3,64 +3,73 @@
3
  import gradio as gr
4
  import pandas as pd
5
  import torch
6
- import duckdb
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
8
 
9
- # 1) Load data and register it in DuckDB
10
  df = pd.read_csv('synthetic_profit.csv')
11
- conn = duckdb.connect(database=':memory:')
12
- conn.register('sap', df)
13
 
14
- # 2) Build a one-line schema description
15
- schema = ", ".join(df.columns) # e.g. "Region, Product, FiscalYear, ..."
 
16
 
17
- # 3) Load TAPEX (WikiSQL) for SQL generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
19
  device = 0 if torch.cuda.is_available() else -1
20
 
21
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
22
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
23
 
24
- sql_generator = pipeline(
25
- "text2text-generation",
26
  model=model,
27
  tokenizer=tokenizer,
28
  framework="pt",
29
- device=device,
30
- # limit length so it doesn’t try to output the entire table!
31
- max_length=128,
32
  )
33
 
34
- # 4) Your new QA function
35
  def answer_profitability(question: str) -> str:
36
- # 4a) Prompt the model to generate SQL
37
- prompt = (
38
- f"Translate to SQL for table `sap` with columns ({schema}):\n"
39
- f"Question: {question}\n"
40
- "SQL:"
41
- )
42
- sql = sql_generator(prompt)[0]['generated_text'].strip()
 
 
 
43
 
44
- # 4b) Execute the generated SQL and return results
45
  try:
46
- result_df = conn.execute(sql).df()
47
- # pretty-print as text
48
- if result_df.empty:
49
- return f"No rows returned. Generated SQL was:\n{sql}"
50
- return result_df.to_string(index=False)
51
  except Exception as e:
52
- # if something goes wrong, show you the SQL so you can debug
53
- return f"Error executing SQL: {e}\n\nGenerated SQL:\n{sql}"
54
 
55
- # 5) Gradio interface
56
  iface = gr.Interface(
57
  fn=answer_profitability,
58
- inputs=gr.Textbox(lines=2, placeholder="Ask about your SAP data…"),
59
- outputs="textbox",
60
- title="SAP Profitability Q&A (SQL-Generation)",
61
  description=(
62
- "Uses TAPEX to translate your natural-language question "
63
- "into a SQL query over the `sap` table, then runs it via DuckDB."
64
  )
65
  )
66
 
 
3
  import gradio as gr
4
  import pandas as pd
5
  import torch
 
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
+ # 1) Load your synthetic profitability dataset
9
  df = pd.read_csv('synthetic_profit.csv')
 
 
10
 
11
+ # 2) Ensure numeric types for Revenue, Profit, ProfitMargin
12
+ for col in ["Revenue", "Profit", "ProfitMargin"]:
13
+ df[col] = pd.to_numeric(df[col], errors='coerce')
14
 
15
+ # 3) Build the schema description
16
+ schema_lines = [f"- {col}: {dtype.name}" for col, dtype in df.dtypes.items()]
17
+ schema_text = "Table schema:\n" + "\n".join(schema_lines)
18
+
19
+ # 4) Few-shot examples teaching SUM and AVERAGE
20
+ few_shot = """
21
+ Example 1
22
+ Q: Total profit by region?
23
+ A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34
24
+
25
+ Example 2
26
+ Q: Average profit margin for Product B in Americas?
27
+ A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18
28
+ """.strip()
29
+
30
+ # 5) Load TAPEX-WikiSQL for table-QA
31
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
32
  device = 0 if torch.cuda.is_available() else -1
33
 
34
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
35
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
36
 
37
+ table_qa = pipeline(
38
+ "table-question-answering",
39
  model=model,
40
  tokenizer=tokenizer,
41
  framework="pt",
42
+ device=device
 
 
43
  )
44
 
45
+ # 6) QA function using schema-aware prompting
46
  def answer_profitability(question: str) -> str:
47
+ # Cast all values to strings so TAPEX can ingest them
48
+ table = df.astype(str).to_dict(orient="records")
49
+
50
+ # Assemble prompt with schema + examples + user question
51
+ prompt = f"""{schema_text}
52
+
53
+ {few_shot}
54
+
55
+ Q: {question}
56
+ A:"""
57
 
 
58
  try:
59
+ out = table_qa(table=table, query=prompt)
60
+ return out.get("answer", "No answer found.")
 
 
 
61
  except Exception as e:
62
+ return f"Error: {e}"
 
63
 
64
+ # 7) Gradio interface
65
  iface = gr.Interface(
66
  fn=answer_profitability,
67
+ inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
68
+ outputs="text",
69
+ title="SAP Profitability Q&A (Schema-Aware TAPEX)",
70
  description=(
71
+ "Every query is prefixed with your table’s schema and two few-shot examples, "
72
+ "so the model learns to SUM, AVERAGE, FILTER, etc., without hard-coded fallbacks."
73
  )
74
  )
75