PD03 commited on
Commit
5bd8928
·
verified ·
1 Parent(s): 519d64c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +10 -36
app.py CHANGED
@@ -4,17 +4,17 @@ import torch
4
  import duckdb
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
 
7
- # 1) Load data into DuckDB
8
  df = pd.read_csv('synthetic_profit.csv')
9
  con = duckdb.connect(':memory:')
10
  con.register('sap', df)
11
 
12
- # 2) Build a one-line schema for prompting
13
- schema = ", ".join(df.columns) # e.g. "Region,Product,FiscalYear,...."
14
 
15
- # 3) Load TAPEX-WikiSQL as a text2text generator
16
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
17
- device = 0 if torch.cuda.is_available() else -1
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
20
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
@@ -28,43 +28,17 @@ sql_gen = pipeline(
28
  max_length=128,
29
  )
30
 
31
- # 4) Core QA fn: NL → SQL → execute → return result
32
  def answer_profitability(question: str) -> str:
33
- # a) Prompt TAPEX to generate SQL
34
  prompt = (
35
- f"-- Translate to SQL over table `sap` with columns ({schema})\n"
36
  f"Question: {question}\n"
37
  "SQL:"
38
  )
39
  sql = sql_gen(prompt)[0]['generated_text'].strip()
40
 
41
- # b) Execute the generated SQL
42
  try:
43
- result_df = con.execute(sql).df()
44
  except Exception as e:
45
- return f" SQL Error: {e}\n\nGenerated SQL:\n{sql}"
46
-
47
- # c) Format the output
48
- if result_df.empty:
49
- return f"No rows returned.\n\nGenerated SQL:\n{sql}"
50
-
51
- # If it's a single cell result, just return that number
52
- if result_df.shape == (1,1):
53
- return str(result_df.iat[0,0])
54
- # Otherwise pretty-print the DataFrame
55
- return result_df.to_string(index=False)
56
-
57
- # 5) Gradio UI
58
- iface = gr.Interface(
59
- fn=answer_profitability,
60
- inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
61
- outputs="text",
62
- title="SAP Profitability Q&A (SQL-Generation)",
63
- description=(
64
- "TAPEX converts your natural-language query into SQL,\n"
65
- "then runs it via DuckDB—no hard-coded fallbacks."
66
- )
67
- )
68
-
69
- if __name__ == "__main__":
70
- iface.launch(server_name="0.0.0.0", server_port=7860)
 
4
  import duckdb
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
 
7
+ # Load data into DuckDB
8
  df = pd.read_csv('synthetic_profit.csv')
9
  con = duckdb.connect(':memory:')
10
  con.register('sap', df)
11
 
12
+ # One-line schema for prompts
13
+ schema = ", ".join(df.columns)
14
 
15
+ # Load TAPEX for SQL generation
16
  MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
17
+ device = 0 if torch.cuda.is_available() else -1
18
 
19
  tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
20
  model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
 
28
  max_length=128,
29
  )
30
 
 
31
  def answer_profitability(question: str) -> str:
32
+ # 1) Ask TAPEX to write SQL
33
  prompt = (
34
+ f"-- Translate to SQL for table `sap` ({schema})\n"
35
  f"Question: {question}\n"
36
  "SQL:"
37
  )
38
  sql = sql_gen(prompt)[0]['generated_text'].strip()
39
 
40
+ # 2) Try to run it
41
  try:
42
+ df_out = con.execute(sql).df()
43
  except Exception as e:
44
+ return f"""**❌ SQL Error**