File size: 1,590 Bytes
a5ece8b
aa97025
8be1581
519d64c
8be1581
887b999
5bd8928
aa97025
519d64c
 
a5ece8b
5bd8928
 
e9af0b4
5bd8928
8be1581
5bd8928
8be1581
aa97025
 
6a97111
519d64c
 
aa97025
 
 
519d64c
 
60fddfe
 
8be1581
79c9d08
519d64c
5bd8928
519d64c
 
 
 
8cc354c
79c9d08
67fc297
5bd8928
67fc297
79c9d08
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
import gradio as gr
import pandas as pd
import torch
import duckdb
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# Load the profit data from CSV and expose it to SQL: the DataFrame is
# registered on an in-memory DuckDB connection under the table name `sap`.
# The CSV path is relative, so it is resolved against the process's current
# working directory.
df = pd.read_csv('synthetic_profit.csv')
con = duckdb.connect(':memory:')
con.register('sap', df)

# Comma-separated column names of the CSV; interpolated into the prompt in
# answer_profitability() so the model sees the table's columns.
schema = ", ".join(df.columns)

# Load the seq2seq checkpoint used to turn questions into SQL.
# NOTE(review): the checkpoint name suggests a TAPEX model fine-tuned on
# WikiSQL; confirm it actually emits SQL text for this prompt format rather
# than direct answers.
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
# Device index handed to the pipeline: 0 when CUDA is available, otherwise -1
# (presumably the pipeline's CPU setting; verify against the transformers docs).
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model     = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# Text-to-text generation pipeline built once at import time and reused for
# every question; max_length=128 is passed through as the generation limit.
sql_gen = pipeline(
    "text2text-generation",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
    max_length=128,
)

def answer_profitability(question: str) -> str:
    """Answer a natural-language question about the `sap` table.

    The question is turned into SQL by the module-level `sql_gen` pipeline,
    the SQL is executed on the module-level DuckDB connection `con`, and the
    outcome is rendered as a Markdown-formatted string.

    Args:
        question: Free-text question from the user.

    Returns:
        A string that is one of:
        - a request to enter a question, when `question` is empty or blank;
        - an error report containing the exception text and the generated
          SQL, when executing the SQL fails;
        - a "No rows returned." notice followed by the generated SQL;
        - the single value as text, when the result is exactly 1x1;
        - otherwise the result as a Markdown table (or a fenced plain-text
          table when the optional `tabulate` package is not installed).
    """
    # Nothing useful can be generated from an empty prompt; bail out before
    # spending a model call and executing whatever it hallucinates.
    if not question or not question.strip():
        return "Please enter a question."

    # 1) Generate SQL
    prompt = (
        f"-- Translate to SQL for table `sap` ({schema})\n"
        f"Question: {question}\n"
        "SQL:"
    )
    sql = sql_gen(prompt)[0]['generated_text'].strip()
    # Shared footer so the user can always inspect what was actually run.
    generated_sql = f"**Generated SQL**\n```sql\n{sql}\n```"

    # 2) Try to execute it
    # NOTE(security): `sql` is model output derived from untrusted user text
    # and is executed verbatim. DuckDB SQL can read and write local files, so
    # restrict the connection/statement types before exposing this publicly.
    try:
        df_out = con.execute(sql).df()
    except Exception as e:
        # Deliberately broad: any failure (parse, bind, conversion) is shown
        # to the user together with the offending SQL instead of crashing.
        return f"❌ **SQL Error**\n```\n{e}\n```\n\n{generated_sql}"

    # 3) Format successful result
    if df_out.empty:
        return f"No rows returned.\n\n{generated_sql}"

    # A single cell (e.g. an aggregate) reads better as a bare value.
    if df_out.shape == (1, 1):
        return str(df_out.iat[0, 0])

    try:
        return df_out.to_markdown(index=False)
    except ImportError:
        # to_markdown needs the optional `tabulate` package; fall back to a
        # fenced plain-text table rather than failing the whole request.
        return f"```\n{df_out.to_string(index=False)}\n```"