PD03 committed on
Commit
996ed5a
·
verified ·
1 Parent(s): 68264bd

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +33 -44
app.py CHANGED
@@ -3,77 +3,66 @@
3
  import gradio as gr
4
  import pandas as pd
5
  import torch
 
6
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
7
 
8
- # 1) Load your synthetic profitability dataset
9
  df = pd.read_csv('synthetic_profit.csv')
 
 
10
 
11
- # 2) Ensure numeric columns for true aggregation
12
- for col in ["Revenue", "Profit", "ProfitMargin"]:
13
- df[col] = pd.to_numeric(df[col], errors='coerce')
14
 
15
- # 3) Build the schema description text
16
- # ← replaced .iteritems() with .items() here
17
- schema_lines = [f"- {col}: {dtype.name}" for col, dtype in df.dtypes.items()]
18
- schema_text = "Table schema:\n" + "\n".join(schema_lines)
19
-
20
- # 4) Few-shot examples teaching SUM and AVERAGE patterns
21
- example_block = """
22
- Example 1
23
- Q: Total profit by region?
24
- A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34
25
-
26
- Example 2
27
- Q: Average profit margin for Product B in Americas?
28
- A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18
29
- """.strip()
30
-
31
- # 5) Model & pipeline setup
32
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
33
  device = 0 if torch.cuda.is_available() else -1
34
 
35
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
36
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
37
 
38
- table_qa = pipeline(
39
- "table-question-answering",
40
  model=model,
41
  tokenizer=tokenizer,
42
  framework="pt",
43
  device=device,
 
 
44
  )
45
 
46
- # 6) QA function with schema-aware prompting
47
  def answer_profitability(question: str) -> str:
48
- # cast all cells to string for safety
49
- table = df.astype(str).to_dict(orient="records")
50
-
51
- # assemble the full prompt
52
- prompt = f"""{schema_text}
53
-
54
- {example_block}
55
-
56
- Q: {question}
57
- A:"""
58
 
 
59
  try:
60
- out = table_qa(table=table, query=prompt)
61
- return out.get("answer", "No answer found.")
 
 
 
62
  except Exception as e:
63
- return f"Error: {e}"
 
64
 
65
- # 7) Gradio interface
66
  iface = gr.Interface(
67
  fn=answer_profitability,
68
- inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
69
- outputs="text",
70
- title="SAP Profitability Q&A (Schema-Aware TAPEX)",
71
  description=(
72
- "Every query is prefixed with your table’s schema and two few-shot examples, "
73
- "so the model learns to SUM, AVERAGE, FILTER, etc., without extra Python code."
74
  )
75
  )
76
 
77
- # 8) Launch the app
78
  if __name__ == "__main__":
79
  iface.launch(server_name="0.0.0.0", server_port=7860)
 
3
  import gradio as gr
4
  import pandas as pd
5
  import torch
6
+ import duckdb
7
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
8
 
9
+ # 1) Load data and register it in DuckDB
10
  df = pd.read_csv('synthetic_profit.csv')
11
+ conn = duckdb.connect(database=':memory:')
12
+ conn.register('sap', df)
13
 
14
+ # 2) Build a one-line schema description
15
+ schema = ", ".join(df.columns) # e.g. "Region, Product, FiscalYear, ..."
 
16
 
17
+ # 3) Load TAPEX (WikiSQL) for SQL generation
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
19
  device = 0 if torch.cuda.is_available() else -1
20
 
21
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
22
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
23
 
24
+ sql_generator = pipeline(
25
+ "text2text-generation",
26
  model=model,
27
  tokenizer=tokenizer,
28
  framework="pt",
29
  device=device,
30
+ # limit length so it doesn’t try to output the entire table!
31
+ max_length=128,
32
  )
33
 
34
+ # 4) Your new QA function
35
  def answer_profitability(question: str) -> str:
36
+ # 4a) Prompt the model to generate SQL
37
+ prompt = (
38
+ f"Translate to SQL for table `sap` with columns ({schema}):\n"
39
+ f"Question: {question}\n"
40
+ "SQL:"
41
+ )
42
+ sql = sql_generator(prompt)[0]['generated_text'].strip()
 
 
 
43
 
44
+ # 4b) Execute the generated SQL and return results
45
  try:
46
+ result_df = conn.execute(sql).df()
47
+ # pretty-print as text
48
+ if result_df.empty:
49
+ return f"No rows returned. Generated SQL was:\n{sql}"
50
+ return result_df.to_string(index=False)
51
  except Exception as e:
52
+ # if something goes wrong, include the generated SQL in the message so it can be debugged
53
+ return f"Error executing SQL: {e}\n\nGenerated SQL:\n{sql}"
54
 
55
+ # 5) Gradio interface
56
  iface = gr.Interface(
57
  fn=answer_profitability,
58
+ inputs=gr.Textbox(lines=2, placeholder="Ask about your SAP data…"),
59
+ outputs="textbox",
60
+ title="SAP Profitability Q&A (SQL-Generation)",
61
  description=(
62
+ "Uses TAPEX to translate your natural-language question "
63
+ "into a SQL query over the `sap` table, then runs it via DuckDB."
64
  )
65
  )
66
 
 
67
  if __name__ == "__main__":
68
  iface.launch(server_name="0.0.0.0", server_port=7860)