# Vatsal Goel
# move to /data
# 47a6e35 unverified
"""
SQLite database utilities for the NL to SQL agent
"""
import os
import sqlite3
from typing import List, Dict, Any, Optional, Tuple
import pandas as pd
from data.schemas import schema_map
from pathlib import Path
# Absolute path of the SQLite database file opened by get_connection()
DB_PATH = Path("/data/hr_database.db")
# Directory named "sql" located next to this module; setup_database() looks
# for "sample_data.sql" inside it
SQL_DIR = os.path.join(os.path.dirname(__file__), "sql")
def get_connection():
    """Open a new SQLite connection to the file at ``DB_PATH``.

    Returns:
        A fresh ``sqlite3.Connection``; the caller is responsible for closing it.
    """
    connection = sqlite3.connect(DB_PATH)
    return connection
def execute_query(query: str) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    """
    Execute a SQL query and return the results

    The query text is handed to the database unmodified; no validation or
    parameter binding happens here, so callers must only pass SQL they are
    willing to run against the database.

    Args:
        query: SQL query to execute

    Returns:
        Tuple of (results as list of dicts, error message if any).
        On success the error message is ``None``; on any failure the result
        list is empty and the message is ``str()`` of the raised exception.
    """
    # The broad ``except`` is deliberate: this function reports every failure
    # to the caller as a string instead of raising.
    try:
        conn = get_connection()
        try:
            # Convert results to pandas DataFrame for easier handling
            df = pd.read_sql_query(query, conn)
        finally:
            # Close the connection even when the query fails, so a bad
            # query does not leak an open database handle.
            conn.close()
        # Convert DataFrame to list of dictionaries (one dict per row)
        results = df.to_dict(orient='records')
    except Exception as e:
        return [], str(e)
    return results, None
def _strip_comment_lines(sql_text: str) -> str:
    """Return ``sql_text`` without the lines that start with ``--``."""
    return '\n'.join(
        line for line in sql_text.splitlines()
        if not line.strip().startswith('--')
    )


def execute_sql_file(file_path: str):
    """Execute all SQL statements in a file

    The script is split into statements on ``;`` and each statement is run
    individually. A statement that fails with ``sqlite3.Error`` is reported on
    stdout and skipped, and the remaining statements are still executed. All
    successful statements are committed at the end.

    Limitations of the simple splitter: a ``;`` inside a string literal or
    inside a comment that shares a line with SQL still splits the statement,
    and a line of a multi-line string literal that starts with ``--`` is
    treated as a comment.

    Args:
        file_path: Path of the SQL script to run

    Raises:
        OSError: If the script file cannot be opened or read.
    """
    # Read the script before opening the database so an unreadable file does
    # not leave a connection behind.
    with open(file_path, 'r') as f:
        sql_script = f.read()

    # Remove full-line comments BEFORE splitting. Filtering whole chunks on a
    # leading "--" would discard every statement that follows a comment line,
    # because the comment and the statement end up in the same chunk.
    statements = []
    for chunk in _strip_comment_lines(sql_script).split(';'):
        # A trailing "-- note" written after a ";" lands at the start of the
        # next chunk; strip it the same way.
        statement = _strip_comment_lines(chunk).strip()
        if statement:
            statements.append(statement)

    conn = get_connection()
    try:
        cursor = conn.cursor()
        for statement in statements:
            try:
                cursor.execute(statement)
            except sqlite3.Error as e:
                # Best-effort: report the failure and keep going
                print(f"Error executing statement: {statement}")
                print(f"Error message: {e}")
        conn.commit()
    finally:
        conn.close()
def _build_create_table_sql(table_name, schema):
    """Return the CREATE TABLE statement for one ``schema_map`` entry.

    Keys are inferred from column names only:

    * a column ending in ``_id`` whose name starts with the table name minus
      its last character (e.g. ``employee`` for ``employees``) becomes the
      primary key; if several columns match, the last one wins;
    * any other column ending in ``_id`` becomes a foreign key to the table
      named ``<prefix>s`` when that table exists in ``schema_map``.
    """
    id_suffix = '_id'
    singular = table_name[:-1]

    columns = []
    primary_key = None
    foreign_keys = []
    for column in schema.columns:
        columns.append(f"{column.name} {column.type}")
        if not column.name.endswith(id_suffix):
            continue
        if column.name.startswith(singular):
            primary_key = column.name
            continue
        # Strip only the trailing "_id"; str.replace would also remove any
        # "_id" occurring earlier in the column name.
        referenced_table = column.name[:-len(id_suffix)] + 's'
        if referenced_table in schema_map:
            foreign_keys.append(f"FOREIGN KEY ({column.name}) REFERENCES {referenced_table} ({column.name})")

    # Table-level constraints go after the column definitions
    if primary_key:
        columns.append(f"PRIMARY KEY ({primary_key})")
    columns.extend(foreign_keys)

    return f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
    {', '.join(columns)}
    )
    """


def setup_database():
    """Set up the database with schema and sample data.

    Deletes any existing database file at ``DB_PATH``, creates one table per
    entry in ``schema_map``, then loads ``sample_data.sql`` from ``SQL_DIR``
    when that file exists. Progress is reported on stdout.
    """
    # Always purge the database first
    if DB_PATH.exists():
        os.remove(DB_PATH)
        print(f"Existing database removed: {DB_PATH}")

    conn = get_connection()
    try:
        cursor = conn.cursor()
        # Create tables based on schema definitions
        for table_name, schema in schema_map.items():
            cursor.execute(_build_create_table_sql(table_name, schema))
        conn.commit()
    finally:
        # Release the handle even if a CREATE TABLE statement fails
        conn.close()

    # The database was just rebuilt, so the sample data is loaded
    # unconditionally whenever the script file is present.
    sample_data_path = os.path.join(SQL_DIR, "sample_data.sql")
    if os.path.exists(sample_data_path):
        print(f"Inserting sample data from {sample_data_path}")
        execute_sql_file(sample_data_path)
    else:
        print(f"Sample data file not found: {sample_data_path}")
    print(f"Database setup complete at {DB_PATH}")
def get_available_tables():
    """Get a list of all available tables in the database

    Returns:
        List of table names read from ``sqlite_master``, excluding names
        matched by the ``sqlite_%`` pattern.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        # Query for all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
        return [row[0] for row in cursor.fetchall()]
    finally:
        # Close the connection on both the success and the failure path
        conn.close()
# Running this module as a script rebuilds the database via setup_database()
if __name__ == "__main__":
    setup_database()