"""Gradio app that turns natural-language questions into SQL for an
expense-tracking schema using the defog/llama-3-sqlcoder-8b model on CPU."""

import os

import gradio as gr
import psutil
import sqlparse
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer


def get_available_memory():
    """Return the currently available system RAM in bytes."""
    return psutil.virtual_memory().available


model_name = "defog/llama-3-sqlcoder-8b"

# Initialize tokenizer
tokenizer = AutoTokenizer.from_pretrained(model_name)

# An 8B-parameter model needs roughly 32 GB of weights in float32 and about
# half that in bfloat16. Full precision is only chosen when it clearly fits,
# leaving some headroom for activations and the rest of the process.
_FLOAT32_MIN_BYTES = 36e9


def load_model():
    """Load the SQLCoder model onto the CPU in a dtype that fits available RAM.

    Returns:
        The loaded causal-LM model, or ``None`` if loading failed (the error
        is printed). ``generate_query`` reports a friendly error when the
        model is ``None``.
    """
    try:
        available_memory = get_available_memory()
        print(f"Available memory: {available_memory / 1e9:.1f} GB")

        # Use the smaller dtype when memory is limited, not the larger one.
        # bfloat16 is preferred over float16 on CPU because PyTorch's CPU
        # kernels have historically supported it more broadly.
        if available_memory > _FLOAT32_MIN_BYTES:
            dtype = torch.float32
            print("Loading model in float32...")
        else:
            dtype = torch.bfloat16
            print("Loading model in bfloat16 with low memory usage...")

        # Llama is a built-in transformers architecture, so no remote code
        # from the model repository needs to be executed.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=dtype,
            device_map="cpu",
            use_cache=True,
            low_cpu_mem_usage=True,
        )
        return model
    except Exception as e:
        print(f"Error loading model: {e}")
        return None


# Load model (this will take some time on first run)
print("Loading model... This may take a few minutes on CPU.")
model = load_model()

# Llama-3 chat-format prompt. Line breaks matter: each DDL column carries a
# trailing `--` comment that must end at the newline, otherwise the rest of
# the schema would be commented out. The prompt ends by opening a ```sql
# fence so the model's completion is the SQL itself.
prompt_template = """<|begin_of_text|><|start_header_id|>user<|end_header_id|>

Generate a SQL query to answer this question: `{question}`

DDL statements:
CREATE TABLE expenses (
  id INTEGER PRIMARY KEY, -- Unique ID for each expense
  date DATE NOT NULL, -- Date when the expense occurred
  amount DECIMAL(10,2) NOT NULL, -- Amount spent
  category VARCHAR(50) NOT NULL, -- Category of expense (food, transport, utilities, etc.)
  description TEXT, -- Optional description of the expense
  payment_method VARCHAR(20), -- How the payment was made (cash, credit_card, debit_card, bank_transfer)
  user_id INTEGER -- ID of the user who made the expense
);

CREATE TABLE categories (
  id INTEGER PRIMARY KEY, -- Unique ID for each category
  name VARCHAR(50) UNIQUE NOT NULL, -- Category name (food, transport, utilities, entertainment, etc.)
  description TEXT -- Optional description of the category
);

CREATE TABLE users (
  id INTEGER PRIMARY KEY, -- Unique ID for each user
  username VARCHAR(50) UNIQUE NOT NULL, -- Username
  email VARCHAR(100) UNIQUE NOT NULL, -- Email address
  created_at TIMESTAMP DEFAULT CURRENT_TIMESTAMP -- When the user account was created
);

CREATE TABLE budgets (
  id INTEGER PRIMARY KEY, -- Unique ID for each budget
  user_id INTEGER, -- ID of the user who set the budget
  category VARCHAR(50), -- Category for which budget is set
  amount DECIMAL(10,2) NOT NULL, -- Budget amount
  period VARCHAR(20) DEFAULT 'monthly', -- Budget period (daily, weekly, monthly, yearly)
  start_date DATE, -- Budget start date
  end_date DATE -- Budget end date
);

-- expenses.user_id can be joined with users.id
-- expenses.category can be joined with categories.name
-- budgets.user_id can be joined with users.id
-- budgets.category can be joined with categories.name<|eot_id|><|start_header_id|>assistant<|end_header_id|>

The following SQL query best answers the question `{question}`:
```sql
"""


def _extract_sql(completion):
    """Return the SQL text from the model's completion (generated tokens only).

    The prompt already opens a ```sql fence, so the completion is the SQL,
    optionally followed by a closing fence and trailing text. A redundant
    leading fence, everything from the closing fence on, and one trailing
    semicolon are removed.
    """
    text = completion.strip()
    if text.startswith("```sql"):
        text = text[len("```sql"):]
    sql_part = text.split("```", 1)[0].strip()
    if sql_part.endswith(";"):
        sql_part = sql_part[:-1]
    return sql_part


def generate_query(question):
    """Generate and pretty-print a SQL query answering ``question``.

    Args:
        question: Natural-language question about the expense schema.

    Returns:
        The formatted SQL string, or an ``"Error: ..."`` message string if the
        model is unavailable or generation fails (errors are never raised, so
        the UI always has something to display).
    """
    if model is None:
        return "Error: Model not loaded properly"

    try:
        updated_prompt = prompt_template.format(question=question)
        # The template already starts with <|begin_of_text|>; stop the
        # tokenizer from prepending a second BOS token.
        inputs = tokenizer(
            updated_prompt, return_tensors="pt", add_special_tokens=False
        )
        prompt_length = inputs["input_ids"].shape[1]

        # Greedy decoding on CPU; sampling parameters (temperature/top_p) are
        # irrelevant when do_sample=False, so they are not passed.
        with torch.no_grad():
            generated_ids = model.generate(
                **inputs,
                num_return_sequences=1,
                eos_token_id=tokenizer.eos_token_id,
                pad_token_id=tokenizer.eos_token_id,
                max_new_tokens=400,
                do_sample=False,
                num_beams=1,
            )

        # Decode only the newly generated tokens so text in the prompt (or in
        # the user's question) cannot be mistaken for model output.
        completion = tokenizer.decode(
            generated_ids[0][prompt_length:], skip_special_tokens=True
        )
        sql_part = _extract_sql(completion)

        return sqlparse.format(sql_part, reindent=True, keyword_case="upper")
    except Exception as e:
        return f"Error generating query: {e}"


def gradio_interface(question):
    """Gradio callback: validate the input, then delegate to generate_query."""
    if not question or not question.strip():
        return "Please enter a question."
    return generate_query(question)


# Create Gradio interface
iface = gr.Interface(
    fn=gradio_interface,
    inputs=gr.Textbox(
        label="Question",
        placeholder="Enter your question (e.g., 'Show me all expenses for food category')",
        lines=3,
    ),
    outputs=gr.Code(label="Generated SQL Query", language="sql"),
    title="SQL Query Generator",
    description="Generate SQL queries from natural language questions about expense tracking database.",
    examples=[
        ["Show me all expenses for food category"],
        ["What's the total amount spent on transport this month?"],
        ["Insert a new expense of 50 dollars for groceries on 2024-01-15"],
        ["Find users who spent more than 1000 dollars total"],
        ["Show me the budget vs actual spending for each category"],
    ],
    cache_examples=False,
)

if __name__ == "__main__":
    iface.launch()