File size: 2,142 Bytes
ef15351
 
a5ece8b
aa97025
8be1581
996ed5a
8be1581
887b999
996ed5a
aa97025
996ed5a
 
a5ece8b
996ed5a
 
b8b6a66
996ed5a
8be1581
8cc354c
8be1581
aa97025
 
6a97111
996ed5a
 
aa97025
 
 
8be1581
996ed5a
 
60fddfe
 
996ed5a
8be1581
996ed5a
 
 
 
 
 
 
8cc354c
996ed5a
67fc297
996ed5a
 
 
 
 
67fc297
996ed5a
 
887b999
996ed5a
aa97025
 
996ed5a
 
 
ef15351
996ed5a
 
ef15351
887b999
a5ece8b
aa97025
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
# app.py

import gradio as gr
import pandas as pd
import torch
import duckdb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# 1) Read the CSV with pandas and expose the frame to DuckDB as table `sap`
df = pd.read_csv("synthetic_profit.csv")
conn = duckdb.connect(database=":memory:")
conn.register("sap", df)

# 2) Comma-separated column names, interpolated into the generation prompt
schema = ", ".join(df.columns)

# 3) TAPEX (WikiSQL fine-tune) checkpoint plus the device index handed to
#    the pipeline: 0 when CUDA is available, -1 otherwise
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
if torch.cuda.is_available():
    device = 0
else:
    device = -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# Text-to-text pipeline wrapping the checkpoint loaded above.
sql_generator = pipeline(
    task="text2text-generation",
    framework="pt",
    model=model,
    tokenizer=tokenizer,
    device=device,
    # Cap the generated length so the model does not try to emit the
    # entire table.
    max_length=128,
)

# 4) Question-answering entry point used by the Gradio interface

# Statement keywords that are allowed to reach the database.
_READ_ONLY_PREFIXES = ("select", "with")


def _is_single_read_only_statement(sql: str) -> bool:
    """Return True if *sql* looks like exactly one SELECT/WITH statement.

    This is a best-effort textual filter, not a SQL parser or a sandbox:
    it rejects empty text, anything containing a statement separator
    (other than one trailing ``;``), and anything that does not start
    with ``SELECT`` or ``WITH``.
    """
    statement = sql.strip().rstrip(";").strip()
    if not statement or ";" in statement:
        return False
    return statement.lower().startswith(_READ_ONLY_PREFIXES)


def answer_profitability(question: str) -> str:
    """Answer a natural-language question by generating and running SQL.

    The question is embedded in a prompt together with the column list of
    the ``sap`` table, the module-level ``sql_generator`` pipeline produces
    a candidate query, and the query is executed on the module-level
    DuckDB connection ``conn``.

    Args:
        question: Free-form question typed by the user.

    Returns:
        The query result rendered as plain text, or a human-readable
        message (always including the generated SQL where one exists)
        when the question is empty, the SQL is refused, the query fails,
        or no rows come back.

    NOTE(review): the generated text is treated as SQL; the configured
    checkpoint appears to be a table question-answering model, so confirm
    it really emits SQL for this prompt.
    """
    # Skip the (slow) model call when there is nothing to translate.
    question = (question or "").strip()
    if not question:
        return "Please enter a question."

    # 4a) Prompt the model to generate SQL
    prompt = (
        f"Translate to SQL for table `sap` with columns ({schema}):\n"
        f"Question: {question}\n"
        "SQL:"
    )
    sql = sql_generator(prompt)[0]['generated_text'].strip()

    # The SQL is model output steered by untrusted user text, so only a
    # single read-only query is ever handed to the database.
    if not _is_single_read_only_statement(sql):
        return (
            "Refusing to run generated SQL because it is not a single "
            f"SELECT statement. Generated SQL was:\n{sql}"
        )

    # 4b) Execute the generated SQL and return results
    try:
        result_df = conn.execute(sql).df()
        if result_df.empty:
            return f"No rows returned. Generated SQL was:\n{sql}"
        # Pretty-print the frame as text for the output textbox.
        return result_df.to_string(index=False)
    except Exception as e:
        # UI boundary: report the failure together with the SQL so the
        # user can debug it instead of surfacing a traceback.
        return f"Error executing SQL: {e}\n\nGenerated SQL:\n{sql}"

# 5) Gradio interface: a two-line question box wired to
#    answer_profitability, with the answer shown in a plain textbox
iface = gr.Interface(
    title="SAP Profitability Q&A (SQL-Generation)",
    description=(
        "Uses TAPEX to translate your natural-language question into a "
        "SQL query over the `sap` table, "
        "then runs it via DuckDB."
    ),
    fn=answer_profitability,
    inputs=gr.Textbox(placeholder="Ask about your SAP data…", lines=2),
    outputs="textbox",
)

# Start the web server only when the file is run as a script.
if __name__ == "__main__":
    iface.launch(server_port=7860, server_name="0.0.0.0")