text2sql / app.py
Sid26Roy's picture
Update app.py
c35bbf4 verified
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlparse
import psutil
import os
def get_available_memory():
    """Return the amount of RAM psutil reports as available, in bytes.

    Thin wrapper around ``psutil.virtual_memory().available``; used to decide
    which dtype the model can be loaded in.
    """
    vm_stats = psutil.virtual_memory()
    return vm_stats.available
# Hub id of the text-to-SQL checkpoint; shared by the tokenizer and the
# model loader below.
model_name = "defog/llama-3-sqlcoder-8b"
# The tokenizer is created eagerly, at import time.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
# CPU-compatible model loading
def choose_torch_dtype(available_memory, num_params=8e9, headroom=1.1):
    """Pick the weight dtype for CPU loading given *available_memory* bytes.

    float32 costs 4 bytes per parameter and half precision 2, so float32 is
    chosen only when it fits with some headroom; otherwise bfloat16 is used,
    halving the memory footprint. bfloat16 is preferred over float16 on CPU
    because some PyTorch CPU ops have historically lacked float16 kernels.

    Args:
        available_memory: Free RAM in bytes.
        num_params: Approximate parameter count. The default assumes the ~8B
            parameters suggested by the "8b" in ``model_name``.
        headroom: Safety factor applied to the float32 weight size.

    Returns:
        ``torch.float32`` or ``torch.bfloat16``.
    """
    float32_bytes = num_params * 4 * headroom
    if available_memory >= float32_bytes:
        return torch.float32
    return torch.bfloat16


def load_model():
    """Load ``model_name`` onto the CPU without quantization.

    The dtype is derived from the currently available RAM via
    ``choose_torch_dtype``.

    Returns:
        The loaded model, or ``None`` if anything fails. The error is printed
        rather than raised so the app can still start; ``generate_query``
        reports a missing model to the user.
    """
    try:
        available_memory = get_available_memory()
        print(f"Available memory: {available_memory / 1e9:.1f} GB")
        # Low memory must select the *smaller* dtype: float32 needs twice the
        # RAM of half precision, so it is only used when it comfortably fits.
        dtype = choose_torch_dtype(available_memory)
        print(f"Loading model in {str(dtype).split('.')[-1]}...")
        loaded = AutoModelForCausalLM.from_pretrained(
            model_name,
            # NOTE(review): Llama-3 architectures are built into transformers;
            # confirm the repo ships no custom modeling code and drop this flag.
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map="cpu",
            use_cache=True,
            low_cpu_mem_usage=True,
        )
        return loaded
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
# Load the model once, at import time, so every request reuses the same
# instance; the first run can take a while on CPU.
_LOAD_NOTICE = "Loading model... This may take a few minutes on CPU."
print(_LOAD_NOTICE)
model = load_model()  # None on failure; generate_query checks for that.
# Prompt sent to the model, written with Llama-3-style chat markup
# (<|begin_of_text|>, header tokens, <|eot_id|>). It embeds the
# expense-tracking schema as DDL plus join hints, and ends with an opened
# ```sql fence so the model's continuation is the SQL itself.
# `{question}` appears twice and is filled in with str.format in
# generate_query, so any literal braces added here must be doubled ({{ }}).
prompt_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: `{question}`
DDL statements:
CREATE TABLE expenses (
id INTEGER PRIMARY KEY, -- Unique ID for each expense
date DATE NOT NULL, -- Date when the expense occurred
amount DECIMAL(10,2) NOT NULL, -- Amount spent
category VARCHAR(50) NOT NULL, -- Category of expense (food, transport, utilities, etc.)
description TEXT, -- Optional description of the expense
payment_method VARCHAR(20), -- How the payment was made (cash, credit_card, debit_card, bank_transfer)
user_id INTEGER -- ID of the user who made the expense
);
CREATE TABLE categories (
id INTEGER PRIMARY KEY, -- Unique ID for each category
name VARCHAR(50) UNIQUE NOT NULL, -- Category name (food, transport, utilities, entertainment, etc.)
description TEXT -- Optional description of the category
);
CREATE TABLE users (
id INTEGER PRIMARY KEY, -- Unique ID for each user
username VARCHAR(50) UNIQUE NOT NULL, -- Username
email VARCHAR(100) UNIQUE NOT NULL, -- Email address
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- When the user account was created
);
CREATE TABLE budgets (
id INTEGER PRIMARY KEY, -- Unique ID for each budget
user_id INTEGER, -- ID of the user who set the budget
category VARCHAR(50), -- Category for which budget is set
amount DECIMAL(10,2) NOT NULL, -- Budget amount
period VARCHAR(20) DEFAULT 'monthly', -- Budget period (daily, weekly, monthly, yearly)
start_date DATE, -- Budget start date
end_date DATE -- Budget end date
);
-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question `{question}`:
```sql
"""
def extract_sql(completion):
    """Pull the SQL statement out of the model's raw completion text.

    The prompt ends with an opening ```sql fence, so the completion should
    begin with the SQL itself and stop at a closing ``` fence. A repeated
    opening fence is tolerated, anything after the closing fence is dropped,
    and a single trailing semicolon is removed.

    Args:
        completion: Decoded text of the newly generated tokens only.

    Returns:
        The bare SQL string (empty if the model produced nothing usable).
    """
    text = completion.strip()
    if text.startswith("```sql"):
        text = text[len("```sql"):]
    sql = text.split("```", 1)[0].strip()
    if sql.endswith(";"):
        sql = sql[:-1]
    return sql


def generate_query(question):
    """Generate a formatted SQL query that answers *question*.

    The question is substituted into ``prompt_template``, the model completes
    it with greedy decoding on CPU, and the SQL is extracted from the newly
    generated tokens and pretty-printed with sqlparse.

    Returns:
        The formatted SQL string, or an error message string. Errors are
        returned instead of raised so the UI can display them.
    """
    if model is None:
        return "Error: Model not loaded properly"
    try:
        updated_prompt = prompt_template.format(question=question)
        # The template already starts with <|begin_of_text|>; don't let the
        # tokenizer add special tokens (which could prepend a second BOS).
        inputs = tokenizer(updated_prompt, return_tensors="pt", add_special_tokens=False)
        prompt_length = inputs["input_ids"].shape[1]
        # Greedy decoding. temperature/top_p only affect sampling, so they are
        # not passed when do_sample=False.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=400,
                do_sample=False,
                num_beams=1,
            )
        # Decode only the continuation: the prompt itself contains a ```sql
        # fence (and the user's question, which may contain backticks), so
        # parsing the full sequence could pick up the wrong text.
        completion = tokenizer.decode(
            generated_ids[0][prompt_length:], skip_special_tokens=True
        )
        sql_part = extract_sql(completion)
        return sqlparse.format(sql_part, reindent=True, keyword_case="upper")
    except Exception as e:
        return f"Error generating query: {e}"
def gradio_interface(question):
    """Gradio callback: validate the question, then delegate to generate_query.

    Args:
        question: Text from the input box. ``None`` is treated like an empty
            box instead of raising AttributeError.

    Returns:
        A prompt asking for a question when the input is blank, otherwise the
        string returned by ``generate_query``.
    """
    if question is None or not question.strip():
        return "Please enter a question."
    # Strip surrounding whitespace/newlines so they don't end up inside the
    # backtick-quoted question in the prompt.
    return generate_query(question.strip())
# Gradio UI: a three-line question box in, SQL-highlighted code out.
_question_input = gr.Textbox(
    label="Question",
    placeholder="Enter your question (e.g., 'Show me all expenses for food category')",
    lines=3,
)
_sql_output = gr.Code(label="Generated SQL Query", language="sql")
_example_questions = [
    "Show me all expenses for food category",
    "What's the total amount spent on transport this month?",
    "Insert a new expense of 50 dollars for groceries on 2024-01-15",
    "Find users who spent more than 1000 dollars total",
    "Show me the budget vs actual spending for each category",
]
iface = gr.Interface(
    fn=gradio_interface,
    inputs=_question_input,
    outputs=_sql_output,
    title="SQL Query Generator",
    description="Generate SQL queries from natural language questions about expense tracking database.",
    # One row per example; each row supplies the single input component.
    examples=[[q] for q in _example_questions],
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch()