import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import sqlparse
import psutil
import os
# Check available memory
def get_available_memory():
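    """Return the available system RAM in bytes (via psutil)."""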
    return psutil.virtual_memory().available
model_name = "defog/llama-3-sqlcoder-8b"
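# Note: an 8B-parameter model occupies roughly 16 GB of weights in float16
# and ~32 GB in float32, before activations and the KV cache.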
# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)
# CPU-compatible model loading
def load_model():
    try:
        available_memory = get_available_memory()
        print(f"Available memory: {available_memory / 1e9:.1f} GB")
        # For CPU deployment, use float16 or float32 without quantization
        if available_memory > 16e9:  # 16 GB+ of RAM
            print("Loading model in float16...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                torch_dtype=torch.float16,
                device_map="cpu",
                use_cache=True,
                low_cpu_mem_usage=True,
            )
        else:
            print("Loading model in float32 with low memory usage...")
            model = AutoModelForCausalLM.from_pretrained(
                model_name,
                trust_remote_code=True,
                device_map="cpu",
                use_cache=True,
                low_cpu_mem_usage=True,
                torch_dtype=torch.float32,
            )
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None
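# Note: float16 inference on CPU is slow, and some PyTorch CPU kernels lack
# fp16 support; if generation raises a dtype error, torch.float32 (or
# torch.bfloat16 on CPUs that support it) is a reasonable fallback.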
# Load model (this will take some time on first run)
print("Loading model... This may take a few minutes on CPU.")
model = load_model()
prompt_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>
Generate a SQL query to answer this question: `{question}`
DDL statements:
CREATE TABLE expenses (
  id INTEGER PRIMARY KEY, -- Unique ID for each expense
  date DATE NOT NULL, -- Date when the expense occurred
  amount DECIMAL(10,2) NOT NULL, -- Amount spent
  category VARCHAR(50) NOT NULL, -- Category of expense (food, transport, utilities, etc.)
  description TEXT, -- Optional description of the expense
  payment_method VARCHAR(20), -- How the payment was made (cash, credit_card, debit_card, bank_transfer)
  user_id INTEGER -- ID of the user who made the expense
);

CREATE TABLE categories (
  id INTEGER PRIMARY KEY, -- Unique ID for each category
  name VARCHAR(50) UNIQUE NOT NULL, -- Category name (food, transport, utilities, entertainment, etc.)
  description TEXT -- Optional description of the category
);

CREATE TABLE users (
  id INTEGER PRIMARY KEY, -- Unique ID for each user
  username VARCHAR(50) UNIQUE NOT NULL, -- Username
  email VARCHAR(100) UNIQUE NOT NULL, -- Email address
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- When the user account was created
);

CREATE TABLE budgets (
  id INTEGER PRIMARY KEY, -- Unique ID for each budget
  user_id INTEGER, -- ID of the user who set the budget
  category VARCHAR(50), -- Category for which the budget is set
  amount DECIMAL(10,2) NOT NULL, -- Budget amount
  period VARCHAR(20) DEFAULT 'monthly', -- Budget period (daily, weekly, monthly, yearly)
  start_date DATE, -- Budget start date
  end_date DATE -- Budget end date
);
-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>
The following SQL query best answers the question `{question}`:
```sql
"""
def generate_query(question):
    if model is None:
        return "Error: Model not loaded properly"
    try:
        updated_prompt = prompt_template.format(question=question)
        inputs = tokenizer(updated_prompt, return_tensors="pt")
        # Generate on CPU with greedy decoding; sampling knobs like
        # temperature/top_p are omitted because they are ignored when
        # do_sample=False and trigger warnings in recent transformers releases.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=400,
                do_sample=False,
                num_beams=1,
            )
        outputs = tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
        # Extract SQL from the decoded output
        if "```sql" in outputs[0]:
            sql_part = outputs[0].split("```sql")[1].split("```")[0].strip()
        else:
            # Fallback extraction when the model omits the code fence
            sql_part = outputs[0].split(
                "The following SQL query best answers the question"
            )[1].strip()
            if sql_part.startswith("`"):
                sql_part = sql_part[1:]
            if "```" in sql_part:
                sql_part = sql_part.split("```")[0].strip()
        # Clean up the SQL: drop a trailing semicolon before formatting
        if sql_part.endswith(";"):
            sql_part = sql_part[:-1]
        # Format the SQL
        formatted_sql = sqlparse.format(sql_part, reindent=True, keyword_case="upper")
        return formatted_sql
    except Exception as e:
        return f"Error generating query: {str(e)}"
def gradio_interface(question):
    if not question.strip():
        return "Please enter a question."
    return generate_query(question)
# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(
        label="Question",
        placeholder="Enter your question (e.g., 'Show me all expenses for food category')",
        lines=3,
    ),
    outputs=gr.Code(label="Generated SQL Query", language="sql"),
    title="SQL Query Generator",
    description="Generate SQL queries from natural-language questions about an expense-tracking database.",
    examples=[
        ["Show me all expenses for food category"],
        ["What's the total amount spent on transport this month?"],
        ["Insert a new expense of 50 dollars for groceries on 2024-01-15"],
        ["Find users who spent more than 1000 dollars total"],
        ["Show me the budget vs actual spending for each category"],
    ],
    cache_examples=False,
)
if __name__ == "__main__":
    iface.launch()