import logging | |
import time | |
import uvicorn | |
from fastapi import FastAPI, HTTPException | |
from pydantic import BaseModel | |
from contextlib import asynccontextmanager | |
from typing import List, Dict, Any | |
# Import necessary components from your kig_core library | |
# Ensure kig_core is in the Python path or installed as a package | |
try: | |
from kig_core.config import settings # Loads config on import | |
from kig_core.schemas import PlannerState, KeyIssue as KigKeyIssue, GraphConfig | |
from kig_core.planner import build_graph | |
from kig_core.graph_client import neo4j_client # Import the initialized client instance | |
from langchain_core.messages import HumanMessage | |
except ImportError as e: | |
print(f"Error importing kig_core components: {e}") | |
print("Please ensure kig_core is in your Python path or installed.") | |
# You might want to exit or raise a clearer error if imports fail | |
raise | |
# Configure logging for the API | |
logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(name)s - %(levelname)s - %(message)s') | |
logger = logging.getLogger(__name__) | |
# --- Pydantic Models for API Request/Response ---
class KeyIssueRequest(BaseModel):
    """Request payload for key-issue generation.

    Validated by pydantic on instantiation.
    """

    # Free-form technical query supplied by the user.
    query: str
class KeyIssueResponse(BaseModel):
    """Response payload carrying the generated key issues."""

    # Reuses the KeyIssue schema from kig_core so the public API contract
    # matches the planner's internal representation.
    key_issues: List[KigKeyIssue]
# --- Global Variables / State ---
# Compiled LangGraph application, built once during startup (see lifespan)
# and shared by all requests for efficiency. Rebuilding per request would be
# safer for statelessness but slower; beware shared LLM/graph state if the
# graph is not stateless. Remains None until startup completes.
app_graph = None
# --- Application Lifecycle (Startup/Shutdown) ---
@asynccontextmanager  # fix: FastAPI's `lifespan=` requires an async context manager;
# the decorator was imported but never applied, so startup hooks would not run.
async def lifespan(app: FastAPI):
    """Build shared resources at startup and release them at shutdown.

    Startup: verifies the Neo4j connection and compiles the LangGraph
    workflow into the module-global ``app_graph``.
    Shutdown: relies on the atexit hook registered by kig_core.graph_client
    to close the Neo4j driver.
    """
    global app_graph
    logger.info("API starting up...")

    # The Neo4j client is initialized on import by kig_core.graph_client;
    # here we only double-check that the connection actually works.
    try:
        logger.info("Verifying Neo4j connection...")
        neo4j_client._get_driver().verify_connectivity()
        logger.info("Neo4j connection verified.")
    except Exception as e:
        # Startup deliberately continues on failure; requests will surface
        # 503s instead. Uncomment the raise below to fail fast.
        logger.error(f"Neo4j connection verification failed on startup: {e}", exc_info=True)
        # raise RuntimeError("Failed to connect to Neo4j on startup.") from e

    # Compile the LangGraph workflow once; requests reuse it via app_graph.
    logger.info("Building LangGraph application...")
    try:
        app_graph = build_graph()
        logger.info("LangGraph application built successfully.")
    except Exception as e:
        # Without a graph the service is useless, so startup must fail.
        logger.error(f"Failed to build LangGraph application on startup: {e}", exc_info=True)
        raise RuntimeError("Failed to build LangGraph on startup.") from e

    yield  # API serves requests here

    # --- Shutdown ---
    logger.info("API shutting down...")
    # Neo4j driver closure is handled by the atexit hook in graph_client.py,
    # so no explicit neo4j_client.close() is needed here.
    logger.info("Neo4j client closed (likely via atexit).")
    logger.info("API shutdown complete.")
# --- FastAPI Application ---
# Single application instance; the lifespan context manager above handles
# startup and shutdown of shared resources.
app = FastAPI(
    title="Key Issue Generator API",
    description=(
        "API to generate Key Issues based on a technical query "
        "using LLMs and Neo4j."
    ),
    version="1.0.0",
    lifespan=lifespan,
)
# --- API Endpoint ---
@app.get("/")  # fix: route decorator was missing, so this health check was never exposed
def read_root():
    """Health-check endpoint: reports that the API process is alive."""
    return {"status": "ok"}
# fix: the POST route decorator was missing, leaving the endpoint unreachable;
# response_model also enables the validation the code below relies on.
@app.post("/generate-key-issues", response_model=KeyIssueResponse)
async def generate_issues(request: KeyIssueRequest):
    """Generate Key Issues for a technical query.

    Runs the LangGraph workflow built at startup and maps workflow
    failures onto HTTP status codes.

    Raises:
        HTTPException: 400 for empty/invalid input, 502 for upstream LLM
            failures, 503 for database/initialization problems, 500 for
            anything unexpected.
    """
    global app_graph
    if app_graph is None:
        logger.error("Graph application is not initialized.")
        raise HTTPException(status_code=503, detail="Service Unavailable: Graph not initialized")

    user_query = request.query
    if not user_query:
        raise HTTPException(status_code=400, detail="Query cannot be empty.")

    logger.info(f"Received request to generate key issues for query: '{user_query[:100]}...'")
    start_time = time.time()

    try:
        # --- Prepare Initial State for LangGraph ---
        # NOTE(review): assumes PlannerState accepts exactly these keys and
        # that -1 means "not started" for the graph's entry point — confirm
        # against kig_core.schemas / planner if the graph changes.
        initial_state: PlannerState = {
            "user_query": user_query,
            "messages": [HumanMessage(content=user_query)],
            "plan": [],
            "current_plan_step_index": -1,
            "step_outputs": {},
            "key_issues": [],
            "error": None
        }

        # --- Graph Configuration ---
        # No persistent memory/thread id is used; derive a thread_id (e.g.
        # a hash of the query) here if checkpointing is enabled later.
        config: GraphConfig = {"configurable": {}}

        # --- Execute the LangGraph Workflow ---
        logger.info("Invoking LangGraph workflow...")
        # ainvoke returns only the final state; switch to astream if
        # intermediate step states are ever needed.
        final_state = await app_graph.ainvoke(initial_state, config=config)

        end_time = time.time()
        logger.info(f"Workflow finished in {end_time - start_time:.2f} seconds.")

        # --- Process Final Results ---
        if final_state is None:
            logger.error("Workflow execution did not produce a final state.")
            raise HTTPException(status_code=500, detail="Workflow execution failed to produce a result.")

        if final_state.get("error"):
            error_msg = final_state.get("error", "Unknown error")
            logger.error(f"Workflow failed with error: {error_msg}")
            # Map internal errors to appropriate HTTP status codes.
            status_code = 500  # Internal Server Error by default
            if "Neo4j" in error_msg or "connection" in error_msg.lower():
                status_code = 503  # Service Unavailable (database issue)
            elif "LLM error" in error_msg or "parse" in error_msg.lower():
                status_code = 502  # Bad Gateway (issue with upstream LLM)
            raise HTTPException(status_code=status_code, detail=f"Workflow failed: {error_msg}")

        # --- Extract Key Issues ---
        generated_issues_data = final_state.get("key_issues", [])
        try:
            # FastAPI validates this dict against KeyIssueResponse via
            # response_model; the explicit try guards against surprises
            # when the workflow's output shape drifts.
            response_data = {"key_issues": generated_issues_data}
            logger.info(f"Successfully generated {len(generated_issues_data)} key issues.")
            return response_data
        except Exception as pydantic_error:
            logger.error(f"Failed to validate final key issues against response model: {pydantic_error}", exc_info=True)
            logger.error(f"Data that failed validation: {generated_issues_data}")
            raise HTTPException(status_code=500, detail="Internal error: Failed to format key issues response.")

    except HTTPException:
        # Re-raise HTTPExceptions untouched so the status mapping above survives.
        raise
    except ConnectionError as e:
        logger.error(f"Connection Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=503, detail=f"Service Unavailable: {e}")
    except ValueError as e:
        logger.error(f"Value Error during API request: {e}", exc_info=True)
        raise HTTPException(status_code=400, detail=f"Bad Request: {e}")  # often input validation issues
    except Exception as e:
        logger.error(f"An unexpected error occurred during API request: {e}", exc_info=True)
        # fix: dropped the spurious f-prefix on this constant message
        raise HTTPException(status_code=500, detail="Internal Server Error: An unexpected error occurred.")
# --- How to Run ---
if __name__ == "__main__":
    # Configuration is read from environment variables (NEO4J_URI,
    # NEO4J_PASSWORD, GEMINI_API_KEY, ...) or from a .env file located in
    # the directory this script is launched from.
    print("Starting API server...")
    print("Ensure required environment variables (e.g., NEO4J_URI, NEO4J_PASSWORD, GEMINI_API_KEY) are set or .env file is present.")
    # CLI equivalent: uvicorn api:app --reload --host 0.0.0.0 --port 8000
    # reload=True is a development convenience; use reload=False in production.
    uvicorn.run("api:app", host="0.0.0.0", port=8000, reload=True)