# app.py
"""Gradio app: ask natural-language questions about a profitability CSV.

Questions are answered by the TAPEX table-QA model
(``microsoft/tapex-base-finetuned-wikisql``) through the Hugging Face
``table-question-answering`` pipeline. Every query is prefixed with a
description of the table schema and two few-shot examples.
"""

import logging

import gradio as gr
import pandas as pd
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

logger = logging.getLogger(__name__)

# 1) Load your synthetic profitability dataset.
#    NOTE(review): the path is relative to the current working directory, so
#    the app must be launched from the folder that contains the CSV.
df = pd.read_csv("synthetic_profit.csv")

# 2) Ensure numeric columns for true aggregation (optional, but helps you
#    verify sums). Values that cannot be parsed become NaN.
for col in ["Revenue", "Profit", "ProfitMargin"]:
    df[col] = pd.to_numeric(df[col], errors="coerce")

# 3) Build the schema description text.
#    `Series.items()` replaces `iteritems()`, which was removed in pandas 2.0.
schema_lines = [f"- {col}: {dtype.name}" for col, dtype in df.dtypes.items()]
schema_text = "Table schema:\n" + "\n".join(schema_lines)

# 4) Few-shot examples teaching SUM and AVERAGE patterns.
#    NOTE(review): the answers below are hard-coded, not recomputed from `df`.
#    Example 2 filters on Region=Americas, while Example 1 lists the regions as
#    EMEA, APAC, Latin America and North America -- confirm against the data.
example_block = """
Example 1
Q: Total profit by region?
A: Group “Profit” by “Region” and sum → EMEA: 30172183.37; APAC: 32301788.32; Latin America: 27585378.50; North America: 25473893.34

Example 2
Q: Average profit margin for Product B in Americas?
A: Filter Product=B & Region=Americas, take mean of “ProfitMargin” → 0.18
""".strip()

# 5) Model & pipeline setup. Device 0 is the first CUDA GPU; -1 selects CPU.
MODEL_ID = "microsoft/tapex-base-finetuned-wikisql"
device = 0 if torch.cuda.is_available() else -1

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)
model = AutoModelForSeq2SeqLM.from_pretrained(MODEL_ID)
table_qa = pipeline(
    "table-question-answering",
    model=model,
    tokenizer=tokenizer,
    framework="pt",
    device=device,
)

# Table handed to the model: every cell cast to string (as before), built once
# here instead of on every request since `df` is not modified after this point.
# Missing cells are blanked so the model never sees the literal text "nan".
_table_for_model = df.astype(str).mask(df.isna(), "")


# 6) QA function with schema-aware prompting
def answer_profitability(question: str) -> str:
    """Answer a free-text profitability question with TAPEX.

    Args:
        question: The user's question, as typed into the Gradio textbox.

    Returns:
        The model's answer text; ``"No answer found."`` when the model returns
        an empty answer; a prompt to type a question when the input is blank;
        or an ``"Error: ..."`` string if the pipeline raised.
    """
    question = (question or "").strip()
    if not question:
        return "Please enter a question."

    # 6a) Assemble the full prompt: schema, few-shot examples, then the question.
    #     NOTE(review): TAPEX encodes the query together with the linearized
    #     table, so this prefix uses part of the model's input budget; with a
    #     large table the input may exceed the model's limit -- confirm the
    #     pipeline/tokenizer truncation behaviour for this model.
    prompt = f"{schema_text}\n\n{example_block}\n\nQ: {question}\nA:"

    # 6b) Call TAPEX. This is the UI boundary, so failures are logged with their
    #     traceback and reported to the user rather than crashing the app.
    try:
        out = table_qa(table=_table_for_model, query=prompt)
    except Exception as e:
        logger.exception("TAPEX table QA failed for question %r", question)
        return f"Error: {e}"
    return out.get("answer") or "No answer found."


# 7) Gradio interface
iface = gr.Interface(
    fn=answer_profitability,
    inputs=gr.Textbox(lines=2, placeholder="Ask a question about profitability…"),
    outputs="text",
    title="SAP Profitability Q&A (Schema-Aware TAPEX)",
    description=(
        "Every query is prefixed with your table’s schema and two few-shot examples, "
        "so the model learns to SUM, AVERAGE, FILTER, etc., without extra Python code."
    ),
)

if __name__ == "__main__":
    iface.launch(server_name="0.0.0.0", server_port=7860)