File size: 1,674 Bytes
ef15351
 
a5ece8b
aa97025
8be1581
 
887b999
ef15351
aa97025
a5ece8b
ef15351
8be1581
 
ef15351
8be1581
887b999
ef15351
aa97025
 
6a97111
ef15351
aa97025
 
 
 
 
8be1581
60fddfe
 
ef15351
8be1581
ef15351
 
 
67fc297
aa97025
 
67fc297
aa97025
887b999
ef15351
aa97025
 
 
6a97111
8be1581
ef15351
 
 
 
887b999
a5ece8b
ef15351
aa97025
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
# app.py

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

# 1) Read the synthetic profitability table that every question is asked against.
df = pd.read_csv("synthetic_profit.csv")

# 2) Checkpoint: TAPEX-base fine-tuned on WikiSQL, fetched by its Hub id.
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"

# 3) Pipeline device index: 0 selects the first CUDA GPU, -1 means CPU.
if torch.cuda.is_available():
    device = 0
else:
    device = -1

# 4) Load the tokenizer and the seq2seq model for that checkpoint.
tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)

# 5) Wrap both in a PyTorch table-question-answering pipeline placed on `device`.
table_qa = pipeline(
    task="table-question-answering",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
)

# 6) Define the QA function, casting all cells to strings to avoid float issues
def answer_profitability(question: str) -> str:
    """Answer a free-form question about the profitability table.

    The module-level ``df`` is passed to the ``table_qa`` pipeline with every
    cell cast to ``str`` first, so numeric cells reach the model as text.

    Args:
        question: Natural-language question typed into the Gradio textbox.

    Returns:
        The model's answer as a string; ``"No answer found."`` when the
        pipeline yields no (or an empty) answer; a prompt to type a question
        when the input is blank; or ``"Error: <message>"`` if the pipeline
        raises.
    """
    # A blank textbox (or None) carries no question; skip the model call.
    if question is None or not question.strip():
        return "Please enter a question about profitability."

    # Hand the pipeline the string-cast DataFrame itself. Converting it to a
    # list of records first only makes the pipeline rebuild a DataFrame.
    table = df.astype(str)
    try:
        out = table_qa(table=table, query=question.strip())
    except Exception as e:  # UI boundary: report the failure instead of crashing
        import logging

        logging.getLogger(__name__).exception(
            "table_qa failed for question %r", question
        )
        return f"Error: {e}"

    answer = out.get("answer") if isinstance(out, dict) else None
    # Treat a missing or empty answer the same way.
    return str(answer) if answer else "No answer found."

# 7) Define Gradio interface: one free-text question in, the answer text out.
iface = gr.Interface(
    fn=answer_profitability,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
    outputs="text",
    title="SAP Profitability Q&A (TAPEX-Base)",
    # Name the checkpoint via MODEL_ID so the description cannot drift from
    # the model actually loaded.
    description=(
        "Free-form questions on the synthetic profitability dataset, "
        f"powered end-to-end by {MODEL_ID}."
    ),
)

# 8) Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    # Listen on all network interfaces (0.0.0.0), not just localhost,
    # at the fixed port 7860.
    iface.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )