File size: 4,334 Bytes
2e83155
 
 
 
 
 
 
 
 
 
 
47a6e35
2e83155
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
"""
SQLite database utilities for the NL to SQL agent
"""
import os
import sqlite3
from typing import List, Dict, Any, Optional, Tuple
import pandas as pd
from data.schemas import schema_map
from pathlib import Path

# Absolute path of the SQLite database file opened by get_connection()
DB_PATH = Path("/data/hr_database.db")
# Directory of SQL scripts: the "sql" folder located next to this module
# (setup_database() looks for sample_data.sql in it)
SQL_DIR = os.path.join(os.path.dirname(__file__), "sql")

def get_connection():
    """Open a new sqlite3 connection to the database file at DB_PATH.

    Each call creates a fresh connection; the caller is responsible for
    closing it.
    """
    connection = sqlite3.connect(DB_PATH)
    return connection

def execute_query(query: str) -> Tuple[List[Dict[str, Any]], Optional[str]]:
    """
    Execute a SQL query and return the results

    The query is run through pandas.read_sql_query on a fresh connection,
    which is always closed again, whether the query succeeds or fails.

    Args:
        query: SQL query to execute

    Returns:
        Tuple of (results as list of dicts, error message if any).
        On success the error message is None; on failure the result list is
        empty and the error message is the string form of the exception.
    """
    conn = None
    try:
        conn = get_connection()
        # Convert results to pandas DataFrame for easier handling
        df = pd.read_sql_query(query, conn)

        # Convert DataFrame to list of dictionaries (one dict per row)
        results = df.to_dict(orient='records')
        return results, None
    except Exception as e:
        # Deliberately broad: callers receive the failure as a message
        # instead of an exception.
        return [], str(e)
    finally:
        # Close on every path; previously a failing query leaked the connection.
        if conn is not None:
            conn.close()

def _split_sql_statements(sql_script: str) -> List[str]:
    """Split a SQL script into individual statements that contain real SQL.

    The script is cut at ';' characters, but a cut is only accepted once
    sqlite3.complete_statement() confirms the accumulated text is a complete
    statement. This keeps semicolons inside string literals, comments and
    trigger bodies from splitting a statement in half.

    Chunks that consist solely of blank lines and '--' comment lines are
    dropped. A statement that merely *starts* with a comment line is kept.
    """
    *terminated, tail = sql_script.split(';')

    chunks = []
    pending = ""
    for piece in terminated:
        pending += piece + ';'
        if sqlite3.complete_statement(pending):
            chunks.append(pending)
            pending = ""
    # Whatever remains was never terminated by ';' (e.g. a final statement
    # without a trailing semicolon, or a closing comment).
    chunks.append(pending + tail)

    statements = []
    for chunk in chunks:
        statement = chunk.strip()
        sql_lines = [
            line for line in statement.rstrip(';').splitlines()
            if line.strip() and not line.strip().startswith('--')
        ]
        if sql_lines:
            statements.append(statement)
    return statements


def execute_sql_file(file_path: str):
    """Execute all SQL statements in a file

    Statements are executed one at a time on a fresh connection. A statement
    that fails with a sqlite3.Error is reported on stdout and skipped, and the
    remaining statements still run. All successful work is committed at the end.

    Args:
        file_path: Path of the SQL script to run (read as UTF-8)

    Raises:
        OSError: If the file cannot be opened (raised before any connection
            is created).
    """
    # Read the script first so a missing file never leaves a connection open.
    with open(file_path, 'r', encoding='utf-8') as f:
        sql_script = f.read()

    statements = _split_sql_statements(sql_script)

    conn = get_connection()
    try:
        cursor = conn.cursor()
        for statement in statements:
            try:
                cursor.execute(statement)
            except sqlite3.Error as e:
                # Best effort: report the failure and continue with the rest.
                print(f"Error executing statement: {statement}")
                print(f"Error message: {e}")

        conn.commit()
    finally:
        conn.close()

def setup_database():
    """Set up the database with schema and sample data.

    Steps:
      1. Delete the existing database file at DB_PATH, if any.
      2. Create one table per entry in schema_map. Key constraints are
         inferred from column names (see the comments in the loop).
      3. Load sql/sample_data.sql via execute_sql_file() when that file exists.

    Progress is reported on stdout.
    """
    # Always purge the database first
    if DB_PATH.exists():
        os.remove(DB_PATH)
        print(f"Existing database removed: {DB_PATH}")

    id_suffix = '_id'

    # Build every CREATE TABLE statement before touching the database
    create_statements = []
    for table_name, schema in schema_map.items():
        # Naming heuristic: the table name minus its last character is taken
        # as the singular form (e.g. "employees" -> "employee").
        singular = table_name[:-1]
        columns = []
        primary_key = None
        foreign_keys = []

        for column in schema.columns:
            columns.append(f"{column.name} {column.type}")

            if not column.name.endswith(id_suffix):
                continue

            if column.name.startswith(singular):
                # "<singular>..._id" is treated as this table's primary key
                primary_key = column.name
            else:
                # Any other "*_id" column is a candidate foreign key to the
                # table named "<prefix>s". Only the trailing "_id" is removed;
                # str.replace would also strip "_id" from the middle of names
                # such as "user_identity_id" and miss the referenced table.
                referenced_table = column.name[:-len(id_suffix)] + 's'
                if referenced_table in schema_map:
                    foreign_keys.append(f"FOREIGN KEY ({column.name}) REFERENCES {referenced_table} ({column.name})")

        # Add primary key constraint
        if primary_key:
            columns.append(f"PRIMARY KEY ({primary_key})")

        # Add foreign key constraints
        columns.extend(foreign_keys)

        create_table_sql = f"""
        CREATE TABLE IF NOT EXISTS {table_name} (
            {', '.join(columns)}
        )
        """
        create_statements.append(create_table_sql)

    # Create the tables; close the connection even if a statement fails
    conn = get_connection()
    try:
        cursor = conn.cursor()
        for create_table_sql in create_statements:
            cursor.execute(create_table_sql)
        conn.commit()
    finally:
        conn.close()

    # The database was just recreated, so the tables are empty:
    # populate them from the sample data SQL file when it is available
    sample_data_path = os.path.join(SQL_DIR, "sample_data.sql")
    if os.path.exists(sample_data_path):
        print(f"Inserting sample data from {sample_data_path}")
        execute_sql_file(sample_data_path)
    else:
        print(f"Sample data file not found: {sample_data_path}")

    print(f"Database setup complete at {DB_PATH}")

def get_available_tables():
    """Get a list of all available tables in the database

    Returns:
        Names of all tables listed in sqlite_master, excluding those whose
        name matches the LIKE pattern 'sqlite_%' (SQLite's internal tables).
    """
    conn = get_connection()
    try:
        cursor = conn.cursor()

        # Query for all table names
        cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name NOT LIKE 'sqlite_%'")
        return [row[0] for row in cursor.fetchall()]
    finally:
        # Close on every path so a failing query does not leak the connection
        conn.close()

# When executed directly as a script, rebuild the database from scratch
# (setup_database() deletes any existing database file first).
if __name__ == "__main__":
    setup_database()