talk_to_data / app.py
PD03's picture
Update app.py
ba55f08 verified
raw
history blame
2.57 kB
import os
import gradio as gr
import pandas as pd
import duckdb
import openai
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
import openai.error
# --- OpenAI credentials ------------------------------------------------------
# Read the API key from the environment; it stays None when the variable is unset.
openai.api_key = os.environ.get("OPENAI_API_KEY")

# --- DuckDB: expose the CSV as the in-memory table `sap` -----------------------
df = pd.read_csv("synthetic_profit.csv")
conn = duckdb.connect(":memory:")
conn.register("sap", df)
# Comma-separated column names; interpolated into the SQL-generation prompts.
schema = ", ".join(df.columns)

# --- Hugging Face fallback model, built once at import time -------------------
HF_MODEL = "google/flan-t5-small"
hf_tok = AutoTokenizer.from_pretrained(HF_MODEL)
hf_mod = AutoModelForSeq2SeqLM.from_pretrained(HF_MODEL)
hf_gen = pipeline(
    "text2text-generation",
    model=hf_mod,
    tokenizer=hf_tok,
    device=-1,  # -1 selects the CPU in transformers pipelines
)
def _strip_code_fences(text: str) -> str:
    """Return *text* stripped of whitespace and of a wrapping Markdown code fence.

    Handles multi-line fences with an optional language tag ("```sql\\n...\\n```")
    as well as single-line fences ("```SELECT 1```").
    """
    text = text.strip()
    # Need both an opening and a closing fence that do not overlap.
    if len(text) < 6 or not (text.startswith("```") and text.endswith("```")):
        return text
    inner = text[3:-3]
    first_line, newline, rest = inner.partition("\n")
    # Drop the opening-fence line when it is empty or just a language tag.
    if newline and first_line.strip().lower() in ("", "sql", "duckdb", "sqlite", "postgresql"):
        inner = rest
    return inner.strip()


def generate_sql(question: str) -> str:
    """Translate a natural-language *question* into SQL over the DuckDB table ``sap``.

    OpenAI ``gpt-3.5-turbo`` is asked first. If OpenAI rejects the request with
    HTTP 429 (``openai.RateLimitError``: rate limit or exhausted quota), the local
    Flan-T5 pipeline ``hf_gen`` generates the query instead. A Markdown code fence
    around the result is removed in both cases.

    Args:
        question: The user's question in plain language.

    Returns:
        The generated SQL text. It is not validated here; the caller executes it.

    Raises:
        Any other OpenAI error (e.g. authentication failures) propagates unchanged.
    """
    # NOTE(review): `openai.chat.completions` / `openai.RateLimitError` are the
    # openai>=1.0 API. The file's module-level `import openai.error` only exists in
    # openai<1.0 (where `openai.chat.completions` does not), so that import should be
    # dropped when running on 1.x.
    system_prompt = (
        f"You are an expert SQL generator for DuckDB table `sap` with columns: {schema}.\n"
        "Translate the user’s question into a valid SQL query. Return ONLY the SQL."
    )
    try:
        resp = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": system_prompt},
                {"role": "user", "content": question},
            ],
            temperature=0.0,
            max_tokens=150,
        )
        # `content` may be None in the v1 SDK (e.g. refusals); treat it as empty.
        sql = resp.choices[0].message.content or ""
    except openai.RateLimitError:
        # 429 -> fall back to the locally loaded Hugging Face model.
        fallback_prompt = f"Translate to SQL over `sap({schema})`:\n{question}"
        sql = hf_gen(fallback_prompt, max_length=128)[0]["generated_text"]
    return _strip_code_fences(sql)
def answer_profitability(question: str) -> str:
    """Answer a natural-language *question* about the ``sap`` table.

    The question is converted to SQL by ``generate_sql`` and run on the
    module-level DuckDB connection ``conn``.

    Args:
        question: The user's question as typed into the Gradio textbox.

    Returns:
        A display string. A 1x1 result gives the bare value. Larger results give a
        Markdown table (plain-text table if ``tabulate`` is not installed). Otherwise
        a message is returned, with the SQL appended when one was generated, for:
        failed SQL generation, a failed query, or an empty result.
    """
    try:
        sql = generate_sql(question)
    except Exception as e:  # UI boundary: report the failure instead of crashing the request
        import logging

        logging.getLogger(__name__).exception("SQL generation failed")
        # Only the exception type is shown; provider error messages may carry
        # account/request details that belong in the server log, not the UI.
        return f"❌ Could not generate SQL ({type(e).__name__})."

    # NOTE(review): `sql` is model output derived from user input and is executed
    # verbatim. DuckDB statements/table functions (e.g. COPY, read_csv) can touch the
    # filesystem unless external access is disabled on the connection; consider
    # restricting `conn` if this app is publicly reachable.
    try:
        out_df = conn.execute(sql).df()
    except Exception as e:
        return f"❌ SQL error:\n{e}\n\n```sql\n{sql}\n```"

    if out_df.empty:
        return f"No results.\n\n```sql\n{sql}\n```"
    if out_df.shape == (1, 1):
        return str(out_df.iat[0, 0])
    try:
        return out_df.to_markdown(index=False)
    except ImportError:
        # DataFrame.to_markdown depends on the optional `tabulate` package;
        # fall back to pandas' built-in plain-text rendering when it is missing.
        return out_df.to_string(index=False)
# Gradio UI: one question box in, one answer box out, backed by answer_profitability.
iface = gr.Interface(
    fn=answer_profitability,
    inputs=gr.Textbox(lines=2, label="Question"),
    outputs=gr.Textbox(lines=8, label="Answer"),
    title="SAP Profitability Q&A",
    # Fixed garbled arrow ("β†’" was a mis-decoded "→").
    description="Uses OpenAI → DuckDB, falling back to Flan-T5-Small on 429s.",
    # NOTE(review): newer Gradio releases (5.x) deprecate `allow_flagging` in favour
    # of `flagging_mode`; confirm against the pinned Gradio version.
    allow_flagging="never",
)

if __name__ == "__main__":
    # Bind on all interfaces (0.0.0.0) at port 7860 so the app is reachable from
    # outside its container.
    iface.launch(server_name="0.0.0.0", server_port=7860)