|
import gradio as gr |
|
import torch |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import sqlparse |
|
import psutil |
|
import os |
|
|
|
|
|
def get_available_memory():
    """Report how much system RAM psutil says is currently available.

    Returns:
        The ``available`` field of ``psutil.virtual_memory()``, in bytes.
    """
    memory_stats = psutil.virtual_memory()
    return memory_stats.available
|
|
|
# Model repository id handed to from_pretrained(); load_model() reads it as well.
model_name = "defog/llama-3-sqlcoder-8b"

# The tokenizer is created eagerly at import time and shared by generate_query().
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)
|
|
|
|
|
def load_model(float32_min_bytes=34e9, float16_min_bytes=16e9):
    """Load the causal LM on CPU, picking a dtype that fits in available RAM.

    Weight memory is roughly 4 bytes/parameter in float32 and 2 bytes/parameter
    in float16 (assuming ~8B parameters, as the model name suggests: ~32 GB vs
    ~16 GB). float32 is chosen only when there is clearly room for it; in every
    other case the smaller float16 weights are loaded.

    Args:
        float32_min_bytes: Available RAM (bytes) required before loading in
            float32. The default leaves a little headroom above ~32 GB.
        float16_min_bytes: Available RAM (bytes) below which a warning is
            printed that even float16 weights may not fit.

    Returns:
        The loaded model, or None if loading raised (the error is printed so
        the UI can still start and report the failure per request).
    """
    try:
        available_memory = get_available_memory()
        print(f"Available memory: {available_memory / 1e9:.1f} GB")

        if available_memory > float32_min_bytes:
            # Plenty of RAM: float32 is the most broadly supported CPU dtype.
            print("Loading model in float32...")
            dtype = torch.float32
        else:
            # Never fall back to float32 when memory is tight: it needs twice
            # the memory of float16.
            if available_memory <= float16_min_bytes:
                print("Warning: available memory may be insufficient even for float16 weights.")
            print("Loading model in float16 with low memory usage...")
            dtype = torch.float16

        # NOTE(review): trust_remote_code=True executes Python code shipped with
        # the checkpoint. Keep it only if this checkpoint actually requires it.
        return AutoModelForCausalLM.from_pretrained(
            model_name,
            trust_remote_code=True,
            torch_dtype=dtype,
            device_map="cpu",
            use_cache=True,
            low_cpu_mem_usage=True,
        )

    except Exception as e:
        # Deliberate best-effort: callers check for None instead of crashing.
        print(f"Error loading model: {e}")
        return None
|
|
|
|
|
# Load the model once at import time; generate_query() reads this global.
# load_model() returns None on failure, in which case the UI still starts and
# each request gets an error message instead of SQL.
print("Loading model... This may take a few minutes on CPU.")

model = load_model()
|
|
|
# Prompt sent to the model, written with Llama-3-style header tokens
# (<|start_header_id|>...<|end_header_id|>, <|eot_id|>).
# - {question} appears twice and is filled via str.format() in generate_query(),
#   so any literal braces added to this template must be doubled ({{ }}).
# - The DDL block describes the schema the generated SQL should target.
# - The assistant turn ends with an opened ```sql fence, so the model's
#   continuation is expected to be the SQL query followed by a closing fence.
prompt_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:

CREATE TABLE expenses (
id INTEGER PRIMARY KEY, -- Unique ID for each expense
date DATE NOT NULL, -- Date when the expense occurred
amount DECIMAL(10,2) NOT NULL, -- Amount spent
category VARCHAR(50) NOT NULL, -- Category of expense (food, transport, utilities, etc.)
description TEXT, -- Optional description of the expense
payment_method VARCHAR(20), -- How the payment was made (cash, credit_card, debit_card, bank_transfer)
user_id INTEGER -- ID of the user who made the expense
);

CREATE TABLE categories (
id INTEGER PRIMARY KEY, -- Unique ID for each category
name VARCHAR(50) UNIQUE NOT NULL, -- Category name (food, transport, utilities, entertainment, etc.)
description TEXT -- Optional description of the category
);

CREATE TABLE users (
id INTEGER PRIMARY KEY, -- Unique ID for each user
username VARCHAR(50) UNIQUE NOT NULL, -- Username
email VARCHAR(100) UNIQUE NOT NULL, -- Email address
created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- When the user account was created
);

CREATE TABLE budgets (
id INTEGER PRIMARY KEY, -- Unique ID for each budget
user_id INTEGER, -- ID of the user who set the budget
category VARCHAR(50), -- Category for which budget is set
amount DECIMAL(10,2) NOT NULL, -- Budget amount
period VARCHAR(20) DEFAULT 'monthly', -- Budget period (daily, weekly, monthly, yearly)
start_date DATE, -- Budget start date
end_date DATE -- Budget end date
);

-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""
|
|
|
def _extract_sql(completion):
    """Return the SQL in a model completion, without code fences or a trailing ';'."""
    text = completion.strip()
    # The prompt already opened the ```sql fence, but tolerate a re-opened one.
    if text.startswith("```sql"):
        text = text[len("```sql"):]
    # Anything after the closing fence is commentary, not SQL.
    text = text.split("```", 1)[0].strip()
    if text.endswith(";"):
        text = text[:-1]
    return text


def generate_query(question, max_new_tokens=400):
    """Generate a formatted SQL query that answers *question*.

    Args:
        question: Natural-language question, substituted into ``prompt_template``.
        max_new_tokens: Maximum number of tokens the model may generate.

    Returns:
        The generated SQL reformatted by ``sqlparse`` (re-indented, upper-case
        keywords, trailing semicolon removed), or an error message string if
        the model is not loaded or generation fails.
    """
    if model is None:
        return "Error: Model not loaded properly"

    try:
        updated_prompt = prompt_template.format(question=question)
        inputs = tokenizer(updated_prompt, return_tensors="pt")
        prompt_length = inputs["input_ids"].shape[1]

        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=max_new_tokens,
                # Greedy decoding; sampling knobs (temperature/top_p) would be
                # ignored here and only trigger generation-config warnings.
                do_sample=False,
                num_beams=1,
            )

        # Decode only the newly generated tokens. Splitting the full text
        # (prompt included) on "```sql" depended on the prompt layout and broke
        # when the question itself contained backticks.
        completion = tokenizer.batch_decode(
            generated_ids[:, prompt_length:], skip_special_tokens=True
        )[0]

        sql_part = _extract_sql(completion)
        return sqlparse.format(sql_part, reindent=True, keyword_case='upper')

    except Exception as e:
        # UI boundary: surface the failure as text instead of crashing the app.
        return f"Error generating query: {str(e)}"
|
|
|
def gradio_interface(question):
    """Gradio callback: validate the question, then return the generated SQL.

    Args:
        question: Text from the question textbox; may be None or blank.

    Returns:
        "Please enter a question." for missing/blank input, otherwise the
        result of ``generate_query`` for the trimmed question.
    """
    # A missing value (e.g. an API call without input) is treated like an empty box.
    if question is None or not question.strip():
        return "Please enter a question."

    # Trim so stray newlines from the multi-line textbox don't end up inside
    # the backtick-quoted question in the prompt.
    return generate_query(question.strip())
|
|
|
|
|
# Single-textbox UI: the question is passed to gradio_interface() and the
# returned string is displayed in a code component with SQL highlighting.
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(
        label="Question",
        placeholder="Enter your question (e.g., 'Show me all expenses for food category')",
        lines=3
    ),
    outputs=gr.Code(label="Generated SQL Query", language="sql"),
    title="SQL Query Generator",
    description="Generate SQL queries from natural language questions about expense tracking database.",
    # Clickable sample questions; each inner list holds one value per input component.
    examples=[
        ["Show me all expenses for food category"],
        ["What's the total amount spent on transport this month?"],
        ["Insert a new expense of 50 dollars for groceries on 2024-01-15"],
        ["Find users who spent more than 1000 dollars total"],
        ["Show me the budget vs actual spending for each category"]
    ],
    # Examples are not pre-computed; each would otherwise be a full model
    # generation before the app becomes usable.
    cache_examples=False
)

# Start the web server only when run as a script, not when imported.
if __name__ == "__main__":
    iface.launch()