Spaces:
Sleeping
Sleeping
| """ | |
| SQLite database utilities for the NL to SQL agent | |
| """ | |
| import os | |
| import sqlite3 | |
| from typing import List, Dict, Any, Optional, Tuple | |
| import pandas as pd | |
| from data.schemas import schema_map | |
| from pathlib import Path | |
# Location of the SQLite database file used by every helper in this module.
DB_PATH = Path("/data") / "hr_database.db"
# Directory containing the SQL scripts (e.g. sample_data.sql), located next to this module.
SQL_DIR = os.path.join(os.path.split(__file__)[0], "sql")
def get_connection():
    """Open a new ``sqlite3`` connection to the database file at ``DB_PATH``.

    A fresh connection is created on every call; the caller is responsible
    for closing it.
    """
    connection = sqlite3.connect(DB_PATH)
    return connection
def execute_query(query: str) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    """
    Execute a SQL query and return the results.

    Args:
        query: SQL text to run on a connection from ``get_connection()``.

    Returns:
        A ``(rows, error)`` tuple. On success, ``rows`` has one
        ``{column_name: value}`` dict per result row (built with
        ``DataFrame.to_dict(orient="records")``) and ``error`` is ``None``.
        On any exception, ``rows`` is ``[]`` and ``error`` is ``str(exc)``.
    """
    # NOTE(review): nothing here restricts `query` to read-only statements.
    # If it comes from an LLM or a user, consider a read-only connection
    # (e.g. `PRAGMA query_only = ON`).
    conn = None
    try:
        conn = get_connection()
        # pandas handles cursor/column-name plumbing; convert to plain dicts.
        df = pd.read_sql_query(query, conn)
        return df.to_dict(orient="records"), None
    except Exception as e:  # deliberate: errors are reported to the caller as text
        return [], str(e)
    finally:
        # Close on both paths so failing queries don't leak connections.
        if conn is not None:
            conn.close()
def _strip_leading_sql_comments(sql: str) -> str:
    """Return *sql* without leading whitespace, ``--`` line comments and ``/* */`` block comments."""
    text = sql.lstrip()
    while text.startswith(("--", "/*")):
        if text.startswith("--"):
            end = text.find("\n")
            text = "" if end == -1 else text[end + 1:]
        else:
            end = text.find("*/")
            text = "" if end == -1 else text[end + 2:]
        text = text.lstrip()
    return text


def _split_sql_statements(sql_script: str) -> List[str]:
    """Split a SQL script into individual statements.

    Fragments between semicolons are accumulated until
    ``sqlite3.complete_statement`` reports a complete statement, so a ``;``
    inside a string literal, a comment or a trigger body does not split a
    statement. Leading comments are removed from each statement, and chunks
    containing only comments/whitespace are dropped.
    """
    statements = []
    pending = ""
    for fragment in sql_script.split(";"):
        pending += fragment + ";"
        if not sqlite3.complete_statement(pending):
            continue
        statement = _strip_leading_sql_comments(pending).strip()
        pending = ""
        if statement.rstrip(";").strip():
            statements.append(statement)
    # An unterminated remainder (e.g. an unclosed string literal) is still
    # returned, so executing it reports the error instead of hiding it.
    leftover = _strip_leading_sql_comments(pending).strip()
    if leftover.rstrip(";").strip():
        statements.append(leftover)
    return statements


def execute_sql_file(file_path: str):
    """Execute all SQL statements in a file.

    Statements run one at a time. A statement that raises ``sqlite3.Error``
    is reported on stdout and skipped; the remaining statements still run.
    The connection is committed after the last statement and always closed.

    Args:
        file_path: Path to a UTF-8 encoded SQL script.
    """
    with open(file_path, "r", encoding="utf-8") as f:
        sql_script = f.read()

    conn = get_connection()
    try:
        cursor = conn.cursor()
        for statement in _split_sql_statements(sql_script):
            try:
                cursor.execute(statement)
            except sqlite3.Error as e:
                print(f"Error executing statement: {statement}")
                print(f"Error message: {e}")
        conn.commit()
    finally:
        conn.close()
def _build_create_table_sql(table_name: str, schema: Any) -> str:
    """Return a CREATE TABLE statement for one ``schema_map`` entry.

    Each entry of ``schema.columns`` contributes ``"<name> <type>"``. Keys are
    inferred from column names:

    * A ``*_id`` column whose name starts with ``table_name[:-1]`` (e.g.
      ``employee_id`` in ``employees``) becomes the PRIMARY KEY. If several
      columns match, the last one wins.
    * Any other ``*_id`` column becomes a FOREIGN KEY to ``<prefix>s`` (e.g.
      ``department_id`` -> ``departments``), but only when that table is a
      key of ``schema_map``. The referenced column is assumed to share the
      same name.
    """
    singular = table_name[:-1]
    columns = []
    primary_key = None
    foreign_keys = []
    for column in schema.columns:
        columns.append(f"{column.name} {column.type}")
        if not column.name.endswith("_id"):
            continue
        if column.name.startswith(singular):
            primary_key = column.name
        else:
            # Drop only the trailing "_id"; str.replace would also strip any
            # "_id" that appears earlier in the column name.
            referenced_table = column.name[: -len("_id")] + "s"
            if referenced_table in schema_map:
                foreign_keys.append(
                    f"FOREIGN KEY ({column.name}) "
                    f"REFERENCES {referenced_table} ({column.name})"
                )
    if primary_key:
        columns.append(f"PRIMARY KEY ({primary_key})")
    columns.extend(foreign_keys)
    return f"""
    CREATE TABLE IF NOT EXISTS {table_name} (
        {', '.join(columns)}
    )
    """


def setup_database():
    """Set up the database with schema and sample data.

    Deletes any existing file at ``DB_PATH`` (all data is lost) and creates one
    table per ``schema_map`` entry. If ``SQL_DIR/sample_data.sql`` exists, its
    statements are then executed via ``execute_sql_file``.
    """
    # Always purge the database first
    if DB_PATH.exists():
        DB_PATH.unlink()
        print(f"Existing database removed: {DB_PATH}")

    conn = get_connection()
    try:
        cursor = conn.cursor()
        for table_name, schema in schema_map.items():
            cursor.execute(_build_create_table_sql(table_name, schema))
        conn.commit()
    finally:
        # Close even if a CREATE TABLE fails.
        conn.close()

    # The database was just recreated, so its tables are always empty here:
    # load the sample rows whenever the script is present.
    sample_data_path = os.path.join(SQL_DIR, "sample_data.sql")
    if os.path.exists(sample_data_path):
        print(f"Inserting sample data from {sample_data_path}")
        execute_sql_file(sample_data_path)
    else:
        print(f"Sample data file not found: {sample_data_path}")

    print(f"Database setup complete at {DB_PATH}")
def get_available_tables():
    """Get a list of all available tables in the database.

    Returns:
        Table names from ``sqlite_master`` in the order SQLite returns them.
        Names matching ``LIKE 'sqlite_%'`` (SQLite's internal tables) are
        excluded.
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()
        # Query for all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
        return [row[0] for row in cursor.fetchall()]
    finally:
        # Close even if the query raises, so the connection is not leaked.
        conn.close()
if __name__ == "__main__":
    # Run as a script: rebuild the database (schema + sample data) from scratch.
    setup_database()