# NOTE: the "Spaces: / Sleeping / Sleeping" lines originally here were
# Hugging Face Space page chrome captured during extraction, not source code.
import os

import duckdb
import gradio as gr
import openai
import pandas as pd
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, pipeline

# -- OpenAI credentials ---------------------------------------------------
# NOTE: the former `import openai.error` was removed: that module no longer
# exists in openai>=1.0, which is the API generation this file targets
# (see the `openai.chat.completions.create` call in generate_sql).
openai.api_key = os.getenv("OPENAI_API_KEY")

# -- In-memory DuckDB seeded with the profitability data ------------------
df = pd.read_csv("synthetic_profit.csv")
conn = duckdb.connect(":memory:")
conn.register("sap", df)  # expose the DataFrame as SQL table `sap`
schema = ", ".join(df.columns)  # column list injected into the LLM prompt

# -- Hugging Face fallback pipeline, loaded once at startup ---------------
HF_MODEL = "google/flan-t5-small"
hf_tok = AutoTokenizer.from_pretrained(HF_MODEL)
hf_mod = AutoModelForSeq2SeqLM.from_pretrained(HF_MODEL)
hf_gen = pipeline(
    "text2text-generation", model=hf_mod, tokenizer=hf_tok, device=-1  # CPU
)
def _strip_code_fences(sql: str) -> str:
    """Remove surrounding ``` fences (e.g. ```sql ... ```) from model output."""
    sql = sql.strip()
    if sql.startswith("```") and sql.endswith("```"):
        sql = "\n".join(sql.splitlines()[1:-1])
    return sql


def generate_sql(question: str) -> str:
    """Translate a natural-language question into SQL for DuckDB table `sap`.

    Asks OpenAI's chat completion first; on a rate limit (HTTP 429) it
    falls back to the local Flan-T5 pipeline. Markdown code fences around
    the model output are stripped before returning.

    Args:
        question: The user's question about the profitability data.

    Returns:
        A SQL query string (not validated — callers must handle bad SQL).

    Raises:
        openai.OpenAIError: any non-rate-limit OpenAI failure propagates.
    """
    prompt = (
        f"You are an expert SQL generator for DuckDB table `sap` with columns: {schema}.\n"
        f"Translate the user's question into a valid SQL query. Return ONLY the SQL."
    )
    try:
        resp = openai.chat.completions.create(
            model="gpt-3.5-turbo",
            messages=[
                {"role": "system", "content": prompt},
                {"role": "user", "content": question},
            ],
            temperature=0.0,
            max_tokens=150,
        )
        sql = resp.choices[0].message.content.strip()
    # FIX: the original caught `openai.error.RateLimitError`, but the
    # `openai.error` module was removed in openai>=1.0 — the fallback could
    # never fire and raised AttributeError instead. The v1 exception classes
    # live directly on the `openai` package.
    except openai.RateLimitError:
        # 429: fall back to the local Hugging Face model.
        fallback_prompt = f"Translate to SQL over `sap({schema})`:\n{question}"
        sql = hf_gen(fallback_prompt, max_length=128)[0]["generated_text"]
    return _strip_code_fences(sql)
def answer_profitability(question: str) -> str:
    """Answer a profitability question by generating SQL and running it.

    The generated query executes against the module-level DuckDB
    connection. Failures are reported in the returned text (with the
    offending SQL) rather than raised, so the UI always gets a string.
    """
    sql = generate_sql(question)
    try:
        result = conn.execute(sql).df()
    except Exception as exc:
        return f"β SQL error:\n{exc}\n\n```sql\n{sql}\n```"
    if result.empty:
        return f"No results.\n\n```sql\n{sql}\n```"
    # A single scalar reads better as plain text than a 1x1 table.
    if result.shape == (1, 1):
        return str(result.iat[0, 0])
    return result.to_markdown(index=False)
# -- Gradio UI ------------------------------------------------------------
iface = gr.Interface(
    fn=answer_profitability,
    inputs=gr.Textbox(lines=2, label="Question"),
    outputs=gr.Textbox(lines=8, label="Answer"),
    title="SAP Profitability Q&A",
    description="Uses OpenAI β DuckDB, falling back to Flan-T5-Small on 429s.",
    allow_flagging="never",
)

if __name__ == "__main__":
    # Bind to all interfaces so the app is reachable from inside a container.
    iface.launch(server_name="0.0.0.0", server_port=7860)