PD03 committed on
Commit
0e84c33
·
verified ·
1 Parent(s): 0f2082e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -82
app.py CHANGED
@@ -1,98 +1,52 @@
1
  import gradio as gr
2
  import pandas as pd
3
- import duckdb
4
- import torch
5
- from transformers import T5Tokenizer, AutoModelForSeq2SeqLM, pipeline
6
 
7
- # Load data
8
  df = pd.read_csv("synthetic_profit.csv")
9
- conn = duckdb.connect(":memory:")
10
- conn.register("sap", df)
11
- schema = ", ".join(df.columns)
12
-
13
- # Model & tokenizer
14
- MODEL_ID = "mrm8488/t5-base-finetuned-wikisql"
15
- device = 0 if torch.cuda.is_available() else -1
16
- tokenizer = T5Tokenizer.from_pretrained(MODEL_ID)
17
- model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
18
- sql_generator = pipeline(
19
- "text2text-generation",
20
- model=model,
21
- tokenizer=tokenizer,
22
- framework="pt",
23
- device=device,
24
- max_length=128,
25
  )
26
 
27
- # Prompt→SQL with few-shot
28
- def generate_sql(question: str) -> str:
29
- prompt = f"""
30
- -- DuckDB table `sap` columns: {schema}
31
-
32
- -- EXAMPLE 1
33
- -- Q: What is the total profit by region?
34
- -- SQL:
35
- SELECT
36
- Region,
37
- SUM(Profit) AS total_profit
38
- FROM sap
39
- GROUP BY Region;
40
-
41
- -- EXAMPLE 2
42
- -- Q: What is the total revenue for Product A in EMEA in Q1 2024?
43
- -- SQL:
44
- SELECT
45
- SUM(Revenue) AS total_revenue
46
- FROM sap
47
- WHERE
48
- Product = 'Product A'
49
- AND Region = 'EMEA'
50
- AND FiscalYear = 2024
51
- AND FiscalQuarter = 'Q1';
52
-
53
- -- NOW
54
- Q: {question}
55
- SQL:
56
- """.strip()
57
-
58
- out = sql_generator(prompt)[0]["generated_text"].strip()
59
- if "SELECT" in out.upper():
60
- sql = out[out.upper().index("SELECT"):]
61
- if ";" in sql:
62
- sql = sql[: sql.rindex(";") + 1]
63
- return sql
64
- raise ValueError(f"Did not generate a SELECT; got:\n{out}")
65
 
66
- # NL→SQL→DuckDB→Result
67
- def answer_profitability(question: str) -> str:
68
- try:
69
- sql = generate_sql(question)
70
- except Exception as e:
71
- return f"❌ Prompt/SQL error:\n{e}"
72
 
73
- try:
74
- rel = conn.execute(sql)
75
- if rel is None:
76
- return f"❌ No relation returned for SQL:\n```sql\n{sql}\n```"
77
- df_out = rel.df()
78
- except Exception as e:
79
- return f"❌ SQL execution error:\n{e}\n\nGenerated SQL:\n```sql\n{sql}\n```"
80
 
81
- if df_out.empty:
82
- return f"No rows.\n\n```sql\n{sql}\n```"
83
- if df_out.shape == (1,1):
84
- return str(df_out.iat[0,0])
85
- return df_out.to_markdown(index=False)
86
 
87
- # Gradio UI
88
  iface = gr.Interface(
89
- fn=answer_profitability,
90
- inputs=gr.Textbox(lines=2, label="Question"),
91
- outputs=gr.Textbox(lines=8, label="Answer"),
92
  title="SAP Profitability Q&A",
93
- description="Translate English → SQL → DuckDB → Answer",
 
 
 
94
  allow_flagging="never",
95
  )
96
 
97
  if __name__ == "__main__":
98
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
1
  import gradio as gr
2
  import pandas as pd
3
+ from transformers import pipeline
 
 
4
 
5
+ # 1) Load your synthetic SAP data
6
  df = pd.read_csv("synthetic_profit.csv")
7
+ table = df.astype(str).to_dict(orient="records")
8
+
9
+ # 2) Table‐QA pipeline with TAPEX
10
+ qa = pipeline(
11
+ "table-question-answering",
12
+ model="microsoft/tapex-base-finetuned-wtq",
13
+ tokenizer="microsoft/tapex-base-finetuned-wtq",
14
+ device=-1
 
 
 
 
 
 
 
 
15
  )
16
 
17
+ # 3) Three few-shot examples
18
+ EXAMPLE_PROMPT = """
19
+ Example 1:
20
+ Q: What is the total revenue for Product A in EMEA in Q1 2024?
21
+ A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum Revenue → 3075162.49
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
22
 
23
+ Example 2:
24
+ Q: What is the total cost for Product A in EMEA in Q1 2024?
25
+ A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum Cost → 2894321.75
 
 
 
26
 
27
+ Example 3:
28
+ Q: What is the total margin for Product A in EMEA in Q1 2024?
29
+ A: Filter Product=A & Region=EMEA & FiscalYear=2024 & FiscalQuarter=Q1, then sum ProfitMargin → 0.18
30
+ """
 
 
 
31
 
32
+ def answer_question(question: str) -> str:
33
+ # Prepend the examples to teach the model your pattern
34
+ full_query = EXAMPLE_PROMPT + f"\nQ: {question}\nA:"
35
+ result = qa(table=table, query=full_query)
36
+ return result.get("answer", "No answer found.")
37
 
38
+ # 4) Gradio UI
39
  iface = gr.Interface(
40
+ fn=answer_question,
41
+ inputs=gr.Textbox(lines=2, placeholder="Ask a basic question…", label="Your question"),
42
+ outputs=gr.Textbox(lines=4, label="Answer"),
43
  title="SAP Profitability Q&A",
44
+ description=(
45
+ "Ask simple revenue/cost/margin questions on the synthetic SAP data. "
46
+ "Powered by microsoft/tapex-base-finetuned-wtq with three few‐shot examples."
47
+ ),
48
  allow_flagging="never",
49
  )
50
 
51
  if __name__ == "__main__":
52
+ iface.launch(server_name="0.0.0.0", server_port=7860)