"""FastAPI backend exposing per-session SQLite databases and question sets.

Each client first creates a session, loads a schema ("domain") into that
session's database, and can then inspect the schema, fetch sample rows,
run read-only queries and list the questions stored for a domain.
"""

import json
import os
import re
import sqlite3
import uuid
from typing import Any, Dict, Optional

import sqlparse
from fastapi import FastAPI, HTTPException, Header, Request
from fastapi.responses import HTMLResponse, JSONResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

from app.database import close_session_db, create_session_db
from app.schemas import RunQueryRequest, ValidateQueryRequest

app = FastAPI()
app.mount(
    "/static",
    StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")),
    name="static",
)


@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    """Turn any otherwise unhandled exception into a JSON 500 response.

    NOTE(review): the raw exception text is sent to the client, which can
    expose internal details (paths, SQL); confirm that this is intended.
    """
    return JSONResponse(status_code=500, content={"detail": str(exc)})


# session id -> {"conn": <session database connection>, "domain": <name or None>}
sessions: Dict[str, dict] = {}

BASE_DIR = os.path.dirname(__file__)

# Keywords that must never appear as a standalone word in a user query.
_FORBIDDEN_KEYWORDS = (
    "drop", "attach", "detach", "pragma", "insert", "update", "delete",
)
_FORBIDDEN_KEYWORD_RE = re.compile(
    r"\b(?:" + "|".join(_FORBIDDEN_KEYWORDS) + r")\b"
)


def _is_plain_name(domain: str) -> bool:
    """Return True when *domain* is a bare file stem without path components."""
    return bool(domain) and domain not in (".", "..") and os.path.basename(domain) == domain


def _quote_identifier(name: str) -> str:
    """Return *name* as a double-quoted SQL identifier with quotes escaped."""
    return '"' + name.replace('"', '""') + '"'


def load_questions(domain: str):
    """Load and return the parsed JSON question file of *domain*.

    Raises:
        FileNotFoundError: if the domain is not a plain name or the file
            ``questions/<domain>.json`` does not exist.
    """
    file_path = os.path.join(BASE_DIR, "questions", f"{domain}.json")
    # Refuse path-like names so a request cannot read outside "questions/".
    if not _is_plain_name(domain) or not os.path.exists(file_path):
        raise FileNotFoundError(f"Question file not found: {file_path}")
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)  # Replaced eval with json.load()


def load_schema_sql(domain: str):
    """Return the SQL script stored in ``schemas/<domain>.sql``.

    Raises:
        FileNotFoundError: if the domain is not a plain name or the schema
            file does not exist.
    """
    file_path = os.path.join(BASE_DIR, "schemas", f"{domain}.sql")
    # Refuse path-like names so a request cannot read outside "schemas/".
    if not _is_plain_name(domain) or not os.path.exists(file_path):
        raise FileNotFoundError(f"Schema file not found: {file_path}")
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()


def is_safe_query(sql: str) -> bool:
    """Return True when *sql* looks like a read-only SELECT query.

    The first parsed statement must start with ``select`` and none of the
    forbidden keywords may occur as a whole word anywhere in the input.
    Whole-word matching keeps identifiers such as ``updated_at`` usable,
    and scanning the complete text (not just the first statement) also
    rejects forbidden keywords hidden in a trailing statement.
    """
    lowered = sql.lower()
    statements = sqlparse.parse(lowered)
    if not statements:
        # Empty input yields no statements; previously this raised IndexError.
        return False
    if not str(statements[0]).strip().startswith("select"):
        return False
    return _FORBIDDEN_KEYWORD_RE.search(lowered) is None


def extract_tables(sql: str) -> list:
    """Collect names that appear in table positions of *sql* (best effort).

    The statement is lower-cased and split on whitespace, so only
    whitespace-separated tokens are recognised. Returns the collected names
    as a list in arbitrary order.

    NOTE(review): the block nesting of this function was reconstructed from
    a whitespace-collapsed copy of the file; verify it against history.
    NOTE(review): every ``i += n`` below only changes the local index used
    for look-ahead in the current iteration; it does not make the ``for``
    loop skip tokens.
    NOTE(review): ``.replace(",);", "")`` removes only the exact three
    character sequence ``,);``, not the individual characters.
    """
    tables = set()
    tokens = sql.replace("\n", " ").lower().split()
    in_subquery = in_openquery = in_values = False
    for i, token in enumerate(tokens):
        # Track whether we are inside a "values (" list or another bracket.
        if token == "(" and not in_subquery and not in_values:
            in_values = i > 0 and tokens[i - 1] == "values"
            in_subquery = not in_values
        if token == ")" and (in_subquery or in_values):
            if in_values and i + 1 < len(tokens) and tokens[i + 1] == "as":
                in_values = False
            elif in_subquery:
                in_subquery = False
        # Everything between "openquery (" and the next ")" is ignored.
        if token == "openquery" and i + 1 < len(tokens) and tokens[i + 1] == "(":
            in_openquery = True
        if token == ")" and in_openquery:
            in_openquery = False
        if in_openquery:
            continue
        if token in ["from", "join", "update", "delete", "insert", "into", "using", "apply", "pivot", "table"]:
            next_token = tokens[i + 1].replace(",);", "") if i + 1 < len(tokens) else ""
            if next_token and next_token not in ["select", "where", "on", "order", "group", "having", "as", "("]:
                if i + 2 < len(tokens) and tokens[i + 2] == "as":
                    # NOTE(review): no-op, so a name followed by "as" is
                    # never recorded; confirm this is intended.
                    next_token = next_token
                elif next_token not in ["left", "right", "inner", "outer", "cross", "full"]:
                    tables.add(next_token)
                    i += 1
        elif token == "merge" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["using", "select", "where"]:
                tables.add(next_token)
                i += 2
                # Look ahead to the token that follows "using".
                while i + 1 < len(tokens) and tokens[i + 1] != "using":
                    i += 1
                if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["select", "where"]:
                    tables.add(next_token)
        elif token == "select" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["from", "select"]:
                tables.add(next_token)
                i += 2
                # Look ahead to the token that follows "from".
                while i + 1 < len(tokens) and tokens[i + 1] != "from":
                    i += 1
                if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["where", "join"]:
                    tables.add(next_token)
        elif token == "with":
            while i + 1 < len(tokens) and tokens[i + 1] != "as":
                i += 1
            if i + 2 < len(tokens) and tokens[i + 2] == "(":
                bracket_count = 1
                subquery_start = i + 2
                i += 2
                # NOTE(review): the scan starts on the opening bracket with
                # the counter already at 1, so that bracket is counted twice.
                while i < len(tokens) and bracket_count > 0:
                    if tokens[i] == "(":
                        bracket_count += 1
                    elif tokens[i] == ")":
                        bracket_count -= 1
                    i += 1
                if bracket_count == 0 and i > subquery_start:
                    subquery = " ".join(tokens[subquery_start:i - 1])
                    tables.update(t for t in extract_tables(subquery) if t not in tables)
        elif token == "values" and i + 1 < len(tokens) and tokens[i + 1] == "(":
            while i + 1 < len(tokens) and tokens[i + 1] != "as":
                i += 1
            if i + 2 < len(tokens) and (alias := tokens[i + 2].replace(",);", "")):
                tables.add(alias)
        elif token in ["exists", "in"]:
            subquery_start = i + 1
            while i + 1 < len(tokens) and tokens[i + 1] != ")":
                i += 1
            if i > subquery_start:
                subquery = " ".join(tokens[subquery_start:i + 1])
                tables.update(t for t in extract_tables(subquery) if t not in tables)
    return list(tables)


@app.post("/api/session")
async def create_session():
    """Create a session with a fresh database and return its id."""
    session_id = str(uuid.uuid4())
    sessions[session_id] = {"conn": create_session_db(), "domain": None}
    return {"session_id": session_id}


@app.get("/api/databases")
async def get_databases():
    """List the domains that have a ``questions/<domain>.json`` file."""
    questions_dir = os.path.join(BASE_DIR, "questions")
    if not os.path.exists(questions_dir):
        return {"databases": []}
    suffix = ".json"
    # Strip only the trailing extension; str.replace would also remove
    # ".json" occurring in the middle of a file name.
    return {
        "databases": [
            name[: -len(suffix)]
            for name in os.listdir(questions_dir)
            if name.endswith(suffix)
        ]
    }


@app.post("/api/load-schema/{domain}")
async def load_schema(domain: str, session_id: str = Header(...)):
    """Replace the session database with a fresh one built from *domain*.

    Raises:
        HTTPException: 401 for an unknown session, 404 when the schema file
            is missing, 500 when the schema script fails; in the last case
            the session is discarded.
    """
    if session_id not in sessions:
        raise HTTPException(status_code=401, detail="Invalid session")
    # Read the script before touching the session so that a missing file
    # cannot leave the session pointing at an empty database.
    try:
        schema_sql = load_schema_sql(domain)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc

    previous_conn = sessions[session_id].get("conn")
    conn = create_session_db()
    sessions[session_id] = {"conn": conn, "domain": domain}
    if previous_conn is not None:
        # The replaced connection was previously never closed.
        close_session_db(previous_conn)
    try:
        conn.executescript(schema_sql)
        conn.commit()
    except sqlite3.Error as e:
        close_session_db(conn)  # Cleanup on error
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    return {"message": f"Database {domain} loaded"}


@app.get("/api/schema/{domain}")
async def get_schema(domain: str, session_id: str = Header(...)):
    """Return column names and types for every table of the loaded domain.

    Rows are accessed by column name, which assumes the session connection
    uses a name-addressable row factory — TODO confirm in create_session_db.
    """
    if session_id not in sessions or sessions[session_id]["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [row["name"] for row in cursor.fetchall()]
    schema = {}
    for table in table_names:
        # Quote the identifier so unusual table names cannot break the SQL.
        columns = conn.execute(f"PRAGMA table_info({_quote_identifier(table)});")
        schema[table] = [{"name": col["name"], "type": col["type"]} for col in columns]
    return {"schema": schema}


@app.get("/api/sample-data/{domain}")
async def get_sample_data(domain: str, session_id: str = Header(...)):
    """Return up to five rows, with column names, for every table."""
    if session_id not in sessions or sessions[session_id]["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [row["name"] for row in cursor.fetchall()]
    sample_data = {}
    for table in table_names:
        # One query supplies both the column names and the rows.
        sample = conn.execute(f"SELECT * FROM {_quote_identifier(table)} LIMIT 5")
        sample_data[table] = {
            "columns": [desc[0] for desc in sample.description],
            "rows": [dict(row) for row in sample],
        }
    return {"sample_data": sample_data}


@app.post("/api/run-query")
async def run_query(request: RunQueryRequest, session_id: str = Header(...)):
    """Execute a read-only query against the session database.

    Raises:
        HTTPException: 401 without a loaded database, 400 when the query is
            not an allowed SELECT or SQLite rejects it.
    """
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    if not is_safe_query(request.query):
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    try:
        cursor.execute(request.query)
        if cursor.description:
            columns = [desc[0] for desc in cursor.description]
            return {"columns": columns, "rows": [dict(zip(columns, row)) for row in cursor.fetchall()]}
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # A malformed query is a client error, not a server failure.
        # sqlite3.Warning is not a subclass of sqlite3.Error, and some Python
        # versions use it to refuse multi-statement input.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    return {"message": "Query executed successfully (no results)"}


@app.get("/api/questions/{domain}")
async def get_questions(domain: str, difficulty: Optional[str] = None):
    """Return the questions of *domain*, optionally filtered by difficulty.

    The difficulty comparison is case-insensitive.

    Raises:
        HTTPException: 404 when the domain has no question file.
    """
    try:
        questions = load_questions(domain)
    except FileNotFoundError as exc:
        raise HTTPException(status_code=404, detail=str(exc)) from exc
    if difficulty:
        questions = [q for q in questions if q["difficulty"].lower() == difficulty.lower()]
    return [
        {
            "id": q["id"],
            "title": q["title"],
            "difficulty": q["difficulty"],
            "description": q["description"],
            "hint": q["hint"],
            "expected_sql": q["expected_sql"],
        }
        for q in questions
    ]
def _normalized_rows(cursor, query):
    """Run *query* on *cursor* and return its rows as tuples of lower-cased strings."""
    cursor.execute(query)
    if not cursor.description:
        return []
    return [tuple(str(value).lower() for value in row) for row in cursor.fetchall()]


@app.post("/api/validate")
async def validate_query(request: ValidateQueryRequest, session_id: str = Header(...)):
    """Compare the result of the user's query with that of the expected query.

    Both queries come from the request body, so both must pass the same
    read-only guard that /api/run-query applies before they are executed.
    Results are compared row by row, in order, after converting every value
    to a lower-cased string.

    Raises:
        HTTPException: 401 without a loaded database, 400 when either query
            is not an allowed SELECT or SQLite rejects it.
    """
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    if not (is_safe_query(request.user_query) and is_safe_query(request.expected_query)):
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    try:
        user_result = _normalized_rows(cursor, request.user_query)
        expected_result = _normalized_rows(cursor, request.expected_query)
    except (sqlite3.Error, sqlite3.Warning) as exc:
        # Invalid SQL is a client error; sqlite3.Warning is not a subclass
        # of sqlite3.Error, so it is listed explicitly.
        raise HTTPException(status_code=400, detail=str(exc)) from exc
    matches = user_result == expected_result
    return {"valid": matches, "error": "" if matches else "Results do not match."}


@app.on_event("shutdown")
async def cleanup():
    """Close every session database and forget all sessions on shutdown."""
    # Iterate over a snapshot because entries are removed while looping.
    for session_id in list(sessions):
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]


@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    """Return the contents of ``static/index.html``.

    Raises:
        HTTPException: 500 when the frontend file is missing.
    """
    file_path = os.path.join(BASE_DIR, "static", "index.html")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Frontend file not found: {file_path}")
    # Read as UTF-8 explicitly instead of relying on the platform default.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()