Spaces:
Sleeping
Sleeping
import logging | |
import os | |
from contextlib import asynccontextmanager | |
from fastapi import FastAPI, HTTPException | |
from fastapi.middleware.cors import CORSMiddleware | |
from pydantic import BaseModel | |
from inference_server.session_manager import SessionManager | |
# Configure process-wide logging once at import time: INFO and above,
# with timestamp, logger name and level on every line.
_LOG_FORMAT = "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module logger used by the lifespan hook and the endpoint handlers.
logger = logging.getLogger(__name__)
# Single module-level SessionManager shared by the lifespan hook (shutdown
# cleanup) and every session-related endpoint in this module.
session_manager: SessionManager = SessionManager()
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Handle app startup and shutdown.

    Passed to ``FastAPI(lifespan=...)``. Everything before ``yield`` runs at
    startup; everything after it runs at shutdown. The shutdown step tears
    down every session held by the module-level ``session_manager``.

    The decorator turns this async generator into the async context manager
    FastAPI expects. Cleanup sits in ``finally`` so sessions are still torn
    down if the lifespan context is exited with an exception.
    """
    logger.info("🚀 RobotHub Inference Server starting up...")
    try:
        yield
    finally:
        logger.info("🔄 RobotHub Inference Server shutting down...")
        await session_manager.cleanup_all_sessions()
        logger.info("✅ RobotHub Inference Server shutdown complete")
# FastAPI application; `lifespan` wires up startup/shutdown handling.
app = FastAPI(
    title="RobotHub Inference Server",
    description="Multi-Policy Model Inference Server for Real-time Robot Control",
    version="1.0.0",
    lifespan=lifespan,
)

# CORS is fully open: any origin, method and header is accepted, with
# credentials allowed.
# NOTE(review): wildcard origins combined with allow_credentials=True may let
# any site make credentialed requests (Starlette is believed to echo the
# request Origin in that case; confirm against the installed version).
# In production, specify actual origins.
_cors_options = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_cors_options)
# Request/Response models
class CreateSessionRequest(BaseModel):
    """Request body for creating an inference session.

    Every field is forwarded unchanged to ``session_manager.create_session``
    by the ``create_session`` handler.
    """

    session_id: str  # caller-chosen identifier for the new session
    policy_path: str  # NOTE(review): presumably a model path or hub id; confirm with SessionManager
    transport_server_url: str  # URL of the transport server the session's rooms are created on (per create_session's docstring)
    camera_names: list[str] = ["front"]  # Support multiple cameras
    workspace_id: str | None = None  # Optional workspace ID; when omitted a new one is generated (per create_session's docstring)
    policy_type: str = "act"  # Policy type: act, pi0, pi0fast, smolvla, diffusion
    language_instruction: str | None = None  # For vision-language policies
class CreateSessionResponse(BaseModel):
    """Room identifiers for a newly created session.

    Built by the ``create_session`` handler from the dict returned by
    ``session_manager.create_session``.
    """

    workspace_id: str  # workspace that contains all of the session's rooms
    camera_room_ids: dict[str, str]  # {camera_name: room_id}
    joint_input_room_id: str
    joint_output_room_id: str
class SessionStatusResponse(BaseModel):
    """Status snapshot of a single session.

    ``list_sessions`` builds one of these from each dict returned by
    ``session_manager.list_sessions``.
    """

    session_id: str
    status: str  # NOTE(review): allowed values are defined by SessionManager; confirm there
    policy_path: str
    policy_type: str
    camera_names: list[str]
    workspace_id: str
    rooms: dict  # free-form; key layout is determined by SessionManager
    stats: dict  # free-form session statistics from SessionManager
    inference_stats: dict | None = None  # optional; None when not supplied
    error_message: str | None = None  # presumably set only when the session has failed; confirm
# Health check endpoints
# NOTE(review): no route decorators are visible on the handlers in this
# module; confirm they are registered on `app` (decorators may have been lost).
async def root():
    """Liveness probe: confirm the server process is up and answering."""
    payload = {
        "message": "RobotHub Inference Server is running",
        "status": "healthy",
    }
    return payload
async def health_check():
    """Detailed health check: report status plus the currently held sessions.

    Returns a dict with a fixed ``"healthy"`` status, the number of entries
    in ``session_manager.sessions`` and their keys (the session IDs).
    """
    registry = session_manager.sessions
    return {
        "status": "healthy",
        "active_sessions": len(registry),
        "session_ids": [*registry.keys()],
    }
# Session management endpoints
async def create_session(request: CreateSessionRequest):
    """
    Create a new inference session.

    If workspace_id is provided, all rooms will be created in that workspace.
    If workspace_id is not provided, a new workspace will be generated automatically.
    All rooms for a session (cameras + joints) are always created in the same workspace.

    Returns:
        CreateSessionResponse built from the dict of room IDs returned by
        ``session_manager.create_session``.

    Raises:
        HTTPException: 400 if the session manager raises ``ValueError``;
            500 (after logging the traceback) for any other failure while
            creating the session.
    """
    try:
        room_ids = await session_manager.create_session(
            session_id=request.session_id,
            policy_path=request.policy_path,
            camera_names=request.camera_names,
            transport_server_url=request.transport_server_url,
            workspace_id=request.workspace_id,
            policy_type=request.policy_type,
            language_instruction=request.language_instruction,
        )
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e)) from e
    except Exception as e:
        logger.exception("Failed to create session %s", request.session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to create session: {e!s}"
        ) from e
    # Build the response outside the try: pydantic's ValidationError is a
    # ValueError subclass. A malformed manager result would otherwise be
    # reported to the client as a 400 instead of surfacing as a server error.
    return CreateSessionResponse(**room_ids)
async def list_sessions():
    """Return every session known to the session manager.

    Each dict from ``session_manager.list_sessions`` is validated into a
    ``SessionStatusResponse``.
    """
    summaries = await session_manager.list_sessions()
    return [SessionStatusResponse(**summary) for summary in summaries]
# Session control endpoints
async def start_inference(session_id: str):
    """Begin running inference for ``session_id``.

    Responds 404 if the session manager raises ``KeyError`` for the ID, and
    500 (after logging the traceback) for any other failure.
    """
    try:
        await session_manager.start_inference(session_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    except Exception as exc:
        logger.exception(f"Failed to start inference for session {session_id}")
        raise HTTPException(status_code=500, detail=f"Failed to start inference: {exc!s}")
    # Both handlers re-raise, so reaching this point means the call succeeded.
    return {"message": f"Inference started for session {session_id}"}
async def stop_inference(session_id: str):
    """Stop inference for a session.

    Raises:
        HTTPException: 404 if the session manager raises ``KeyError`` for the
            ID; 500 (after logging the traceback) for any other failure,
            matching ``start_inference`` and ``restart_inference``.
    """
    try:
        await session_manager.stop_inference(session_id)
    except KeyError as e:
        raise HTTPException(
            status_code=404, detail=f"Session {session_id} not found"
        ) from e
    except Exception as e:
        logger.exception("Failed to stop inference for session %s", session_id)
        raise HTTPException(
            status_code=500, detail=f"Failed to stop inference: {e!s}"
        ) from e
    else:
        # Report the action actually performed (this previously said "started").
        return {"message": f"Inference stopped for session {session_id}"}
async def restart_inference(session_id: str):
    """Restart inference for ``session_id``.

    Responds 404 if the session manager raises ``KeyError`` for the ID, and
    500 (after logging the traceback) for any other failure.
    """
    try:
        await session_manager.restart_inference(session_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    except Exception as exc:
        logger.exception(f"Failed to restart inference for session {session_id}")
        raise HTTPException(
            status_code=500,
            detail=f"Failed to restart inference: {exc!s}",
        )
    # Every handler above raises, so only the success path gets here.
    return {"message": f"Inference restarted for session {session_id}"}
async def delete_session(session_id: str):
    """Remove ``session_id`` from the session manager.

    Responds 404 if the session manager raises ``KeyError`` for the ID; any
    other exception propagates unchanged.
    """
    try:
        await session_manager.delete_session(session_id)
    except KeyError:
        raise HTTPException(status_code=404, detail=f"Session {session_id} not found")
    return {"message": f"Session {session_id} deleted"}
# Main entry point: serve the app with uvicorn when run as a script.
# The listen port comes from the PORT environment variable (default 8001).
if __name__ == "__main__":
    import uvicorn

    listen_port = int(os.environ.get("PORT", "8001"))
    uvicorn.run(
        "inference_server.main:app",
        host="0.0.0.0",
        port=listen_port,
        log_level="info",
        reload=False,
    )