|
from fastapi import FastAPI, HTTPException, Header, Request |
|
from fastapi.responses import JSONResponse, HTMLResponse |
|
from fastapi.staticfiles import StaticFiles |
|
from pydantic import BaseModel |
|
import sqlite3 |
|
import sqlparse |
|
import os |
|
import uuid |
|
import json |
|
from typing import Dict, Any |
|
from app.database import create_session_db, close_session_db |
|
from app.schemas import RunQueryRequest, ValidateQueryRequest |
|
|
|
# FastAPI application serving both the JSON API and the static frontend.
app = FastAPI()
# Static assets are served from the "static" directory next to this module.
app.mount("/static", StaticFiles(directory=os.path.join(os.path.dirname(__file__), "static")), name="static")
|
|
|
@app.exception_handler(Exception)
async def custom_exception_handler(request: Request, exc: Exception):
    """Catch-all handler: turn any unhandled exception into a JSON 500 response.

    NOTE(review): this echoes the raw exception text back to the client, which
    can leak internal details (paths, SQL fragments); consider logging the
    exception server-side and returning a generic message instead.
    """
    return JSONResponse(status_code=500, content={"detail": str(exc)})
|
|
|
# Process-local session registry: session_id -> {"conn": <per-session DB
# connection from create_session_db()>, "domain": loaded domain name or None}.
# Not shared across workers/processes.
sessions: Dict[str, dict] = {}

# Directory containing this module; questions/, schemas/ and static/ are
# resolved relative to it.
BASE_DIR: str = os.path.dirname(__file__)
|
|
|
def load_questions(domain: str):
    """Load and return the question list for *domain* from questions/<domain>.json.

    Raises:
        FileNotFoundError: if the file does not exist or *domain* is not a
            plain file name (path components are rejected).
    """
    # `domain` arrives from an untrusted URL path parameter: refuse anything
    # containing path separators so "../" cannot escape the questions dir.
    if not domain or os.path.basename(domain) != domain:
        raise FileNotFoundError(f"Question file not found: {domain}")
    file_path = os.path.join(BASE_DIR, "questions", f"{domain}.json")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Question file not found: {file_path}")
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(file_path, "r", encoding="utf-8") as f:
        return json.load(f)
|
|
|
def load_schema_sql(domain: str):
    """Load and return the DDL script for *domain* from schemas/<domain>.sql.

    Raises:
        FileNotFoundError: if the file does not exist or *domain* is not a
            plain file name (path components are rejected).
    """
    # `domain` arrives from an untrusted URL path parameter: refuse anything
    # containing path separators so "../" cannot escape the schemas dir.
    if not domain or os.path.basename(domain) != domain:
        raise FileNotFoundError(f"Schema file not found: {domain}")
    file_path = os.path.join(BASE_DIR, "schemas", f"{domain}.sql")
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"Schema file not found: {file_path}")
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()
|
|
|
def is_safe_query(sql: str) -> bool:
    """Return True iff *sql* is a single read-only SELECT statement.

    Fixes over the previous version:
    - empty/whitespace-only input no longer raises IndexError (returns False);
    - multiple statements ("SELECT 1; DROP TABLE x") are rejected outright;
    - forbidden keywords are matched against parsed tokens instead of raw
      substrings, so identifiers such as "last_updated" or "deleted_at" no
      longer cause false rejections.
    """
    statements = [s for s in sqlparse.parse(sql) if str(s).strip()]
    if len(statements) != 1:
        return False
    stmt = statements[0]
    if stmt.get_type() != "SELECT":
        return False
    forbidden = {"drop", "attach", "detach", "pragma", "insert", "update", "delete"}
    # flatten() yields every leaf token; is_keyword excludes identifiers.
    for token in stmt.flatten():
        if token.is_keyword and token.value.lower() in forbidden:
            return False
    return True
|
|
|
def extract_tables(sql: str) -> list:
    """Best-effort extraction of table names referenced by *sql*.

    Tokenises on whitespace (no real SQL parsing), so punctuation that is not
    space-separated can defeat it.  Handles FROM/JOIN/UPDATE/... targets,
    MERGE INTO ... USING, SELECT INTO, WITH CTE bodies (recursively),
    VALUES (...) AS aliases, and EXISTS/IN subqueries (recursively); the
    contents of OPENQUERY(...) are skipped.

    NOTE(review): the `i += 1` / `i += 2` / `while ...: i += 1` statements
    inside the `for i, token in enumerate(tokens)` loop have NO effect on the
    iteration -- enumerate rebinds `i` on the next pass, so the intended
    token-skipping is inert and tokens are revisited.  Confirm whether that is
    intentional before relying on this function.
    """
    tables = set()
    # Crude tokeniser: lowercase, newlines folded to spaces, split on whitespace.
    tokens = sql.replace("\n", " ").lower().split()
    in_subquery = in_openquery = in_values = False

    for i, token in enumerate(tokens):
        # Classify an opening parenthesis as a VALUES list or a subquery.
        if token == "(" and not in_subquery and not in_values:
            in_values = i > 0 and tokens[i - 1] == "values"
            in_subquery = not in_values
        if token == ")" and (in_subquery or in_values):
            # A VALUES list only "closes" here when followed by an AS alias.
            if in_values and i + 1 < len(tokens) and tokens[i + 1] == "as": in_values = False
            elif in_subquery: in_subquery = False
        # OPENQUERY(...) bodies (linked-server T-SQL) are skipped wholesale.
        if token == "openquery" and i + 1 < len(tokens) and tokens[i + 1] == "(": in_openquery = True
        if token == ")" and in_openquery: in_openquery = False
        if in_openquery: continue

        if token in ["from", "join", "update", "delete", "insert", "into", "using", "apply", "pivot", "table"]:
            # ",);" is stripped as one literal 3-char sequence -- presumably the
            # intent was to strip ',', ')' and ';' individually; verify.
            next_token = tokens[i + 1].replace(",);", "") if i + 1 < len(tokens) else ""
            if next_token and next_token not in ["select", "where", "on", "order", "group", "having", "as", "("]:
                # NOTE(review): `next_token = next_token` is a no-op; this
                # branch only exists to bypass the elif when AS follows.
                if i + 2 < len(tokens) and tokens[i + 2] == "as": next_token = next_token
                elif next_token not in ["left", "right", "inner", "outer", "cross", "full"]: tables.add(next_token)
            i += 1
        elif token == "merge" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            # MERGE INTO <target> ... USING <source>: record both names.
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["using", "select", "where"]: tables.add(next_token)
            i += 2
            while i + 1 < len(tokens) and tokens[i + 1] != "using": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["select", "where"]: tables.add(next_token)
        elif token == "select" and i + 1 < len(tokens) and tokens[i + 1] == "into":
            # SELECT ... INTO <new_table> FROM <source>: record both names.
            next_token = tokens[i + 2].replace(",);", "") if i + 2 < len(tokens) else ""
            if next_token and next_token not in ["from", "select"]: tables.add(next_token)
            i += 2
            while i + 1 < len(tokens) and tokens[i + 1] != "from": i += 1
            if i + 2 < len(tokens) and (next_token := tokens[i + 2].replace(",);", "")) and next_token not in ["where", "join"]: tables.add(next_token)
        elif token == "with":
            # WITH <cte> AS ( ... ): recurse into the bracketed CTE body.
            while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and tokens[i + 2] == "(":
                # NOTE(review): bracket_count starts at 1 and the loop below
                # then counts the same "(" again at tokens[i] -- suspected
                # off-by-one in the bracket matching; confirm.
                bracket_count = 1
                subquery_start = i + 2
                i += 2
                while i < len(tokens) and bracket_count > 0:
                    if tokens[i] == "(": bracket_count += 1
                    elif tokens[i] == ")": bracket_count -= 1
                    i += 1
                if bracket_count == 0 and i > subquery_start:
                    subquery = " ".join(tokens[subquery_start:i - 1])
                    tables.update(t for t in extract_tables(subquery) if t not in tables)
        elif token == "values" and i + 1 < len(tokens) and tokens[i + 1] == "(":
            # VALUES (...) AS <alias>: the alias is treated as a table name.
            while i + 1 < len(tokens) and tokens[i + 1] != "as": i += 1
            if i + 2 < len(tokens) and (alias := tokens[i + 2].replace(",);", "")): tables.add(alias)
        elif token in ["exists", "in"]:
            # EXISTS ( ... ) / IN ( ... ): recurse into the subquery text.
            subquery_start = i + 1
            while i + 1 < len(tokens) and tokens[i + 1] != ")": i += 1
            if i > subquery_start:
                subquery = " ".join(tokens[subquery_start:i + 1])
                tables.update(t for t in extract_tables(subquery) if t not in tables)
    return list(tables)
|
|
|
@app.post("/api/session")
async def create_session():
    """Allocate a fresh session backed by its own database connection."""
    new_id = str(uuid.uuid4())
    state = {"conn": create_session_db(), "domain": None}
    sessions[new_id] = state
    return {"session_id": new_id}
|
|
|
@app.get("/api/databases")
async def get_databases():
    """List the available practice domains (one per questions/*.json file)."""
    questions_dir = os.path.join(BASE_DIR, "questions")
    if not os.path.exists(questions_dir):
        return {"databases": []}
    names = [f.replace(".json", "") for f in os.listdir(questions_dir) if f.endswith(".json")]
    return {"databases": names}
|
|
|
@app.post("/api/load-schema/{domain}")
async def load_schema(domain: str, session_id: str = Header(...)):
    """(Re)initialise the session's database with *domain*'s schema script.

    Raises:
        HTTPException 401: unknown session.
        HTTPException 404: no schema file for *domain*.
        HTTPException 500: the schema script failed to execute.
    """
    if session_id not in sessions:
        raise HTTPException(status_code=401, detail="Invalid session")
    # Fix: close the previous connection before replacing it -- the old
    # handle was silently leaked on every reload.
    old_conn = sessions[session_id].get("conn")
    if old_conn is not None:
        close_session_db(old_conn)
    sessions[session_id] = {"conn": create_session_db(), "domain": domain}
    try:
        sessions[session_id]["conn"].executescript(load_schema_sql(domain))
        sessions[session_id]["conn"].commit()
    except FileNotFoundError:
        # Unknown domain is a client error, not a server fault.
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=404, detail=f"Unknown database: {domain}")
    except sqlite3.Error as e:
        close_session_db(sessions[session_id]["conn"])
        del sessions[session_id]
        raise HTTPException(status_code=500, detail=f"Database error: {str(e)}")
    return {"message": f"Database {domain} loaded"}
|
|
|
@app.get("/api/schema/{domain}")
async def get_schema(domain: str, session_id: str = Header(...)):
    """Describe every table (column names and declared types) in the loaded DB."""
    if session_id not in sessions or sessions[session_id]["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    table_names = [row["name"] for row in cursor.fetchall()]
    schema = {}
    for table in table_names:
        columns = conn.execute(f"PRAGMA table_info({table});")
        schema[table] = [{"name": col["name"], "type": col["type"]} for col in columns]
    return {"schema": schema}
|
|
|
@app.get("/api/sample-data/{domain}")
async def get_sample_data(domain: str, session_id: str = Header(...)):
    """Return up to 5 sample rows (plus column names) for every table.

    Raises:
        HTTPException 401: unknown session or *domain* not loaded.
    """
    if session_id not in sessions or sessions[session_id]["domain"] != domain:
        raise HTTPException(status_code=401, detail="Invalid session or domain not loaded")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    cursor.execute("SELECT name FROM sqlite_master WHERE type='table';")
    sample_data = {}
    for table in [row["name"] for row in cursor.fetchall()]:
        # Fix: run the per-table query once -- the previous version executed
        # the same SELECT twice (once for .description, once for the rows).
        result = conn.execute(f"SELECT * FROM {table} LIMIT 5")
        columns = [desc[0] for desc in result.description]
        rows = [dict(row) for row in result.fetchall()]
        sample_data[table] = {"columns": columns, "rows": rows}
    return {"sample_data": sample_data}
|
|
|
@app.post("/api/run-query")
async def run_query(request: RunQueryRequest, session_id: str = Header(...)):
    """Execute a user-supplied SELECT against the session's database.

    Raises:
        HTTPException 401: unknown session or no database loaded.
        HTTPException 400: query is not a safe SELECT, or fails to execute.
    """
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    if not is_safe_query(request.query):
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    try:
        cursor.execute(request.query)
    except sqlite3.Error as e:
        # Fix: a syntax error in user SQL is a client error (400); previously
        # it escaped to the catch-all handler and came back as a 500.
        raise HTTPException(status_code=400, detail=f"SQL error: {e}")
    if cursor.description:
        columns = [desc[0] for desc in cursor.description]
        return {"columns": columns, "rows": [dict(zip(columns, row)) for row in cursor.fetchall()]}
    return {"message": "Query executed successfully (no results)"}
|
|
|
@app.get("/api/questions/{domain}")
async def get_questions(domain: str, difficulty: str = None):
    """Return the question list for *domain*, optionally filtered by difficulty.

    Raises:
        HTTPException 404: no question file exists for *domain*.
    """
    try:
        questions = load_questions(domain)
    except FileNotFoundError:
        # Fix: an unknown domain is a 404, not an unhandled 500.
        raise HTTPException(status_code=404, detail=f"Unknown domain: {domain}")
    if difficulty:
        questions = [q for q in questions if q["difficulty"].lower() == difficulty.lower()]
    # NOTE(review): this exposes expected_sql to the client -- confirm that is
    # intended (it lets users see the reference answer).
    return [{"id": q["id"], "title": q["title"], "difficulty": q["difficulty"], "description": q["description"], "hint": q["hint"], "expected_sql": q["expected_sql"]} for q in questions]
|
|
|
@app.post("/api/validate")
async def validate_query(request: ValidateQueryRequest, session_id: str = Header(...)):
    """Check whether the user's query produces the same rows as the expected one.

    Comparison is order-sensitive and case-insensitive on stringified values.

    Raises:
        HTTPException 401: unknown session or no database loaded.
        HTTPException 400: user query is not a safe SELECT, or fails to run.
    """
    if session_id not in sessions or not sessions[session_id]["domain"]:
        raise HTTPException(status_code=401, detail="Invalid session or no database loaded")
    # Fix: enforce the same safety gate as /api/run-query -- previously this
    # endpoint executed arbitrary user SQL (including DROP/DELETE) unchecked.
    if not is_safe_query(request.user_query):
        raise HTTPException(status_code=400, detail="Only SELECT queries are allowed")
    conn = sessions[session_id]["conn"]
    cursor = conn.cursor()
    try:
        cursor.execute(request.user_query)
    except sqlite3.Error as e:
        # Fix: user SQL errors are client errors (400), not unhandled 500s.
        raise HTTPException(status_code=400, detail=f"SQL error: {e}")
    user_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
    cursor.execute(request.expected_query)
    expected_result = [tuple(str(x).lower() for x in row) for row in cursor.fetchall()] if cursor.description else []
    return {"valid": user_result == expected_result, "error": "Results do not match." if user_result != expected_result else ""}
|
|
|
@app.on_event("shutdown")
async def cleanup():
    """Close and discard every open session database on application shutdown."""
    while sessions:
        _, state = sessions.popitem()
        close_session_db(state["conn"])
|
|
|
@app.get("/", response_class=HTMLResponse)
async def serve_frontend():
    """Serve the single-page frontend from static/index.html.

    Raises:
        HTTPException 500: the bundled frontend file is missing (deployment bug).
    """
    file_path = os.path.join(BASE_DIR, "static", "index.html")
    if not os.path.exists(file_path):
        raise HTTPException(status_code=500, detail=f"Frontend file not found: {file_path}")
    # Fix: explicit encoding -- the platform default is not UTF-8 everywhere.
    with open(file_path, "r", encoding="utf-8") as f:
        return f.read()