from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import List, Optional

import httpx
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Ollama API Server",
    description="REST API for running Ollama models",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"


# Pydantic models
class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (user, assistant, system)")
    content: str = Field(..., description="Content of the message")


class ChatRequest(BaseModel):
    model: str = Field(..., description="Model name to use for chat")
    messages: List[ChatMessage] = Field(..., description="List of chat messages")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")


class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")


class ModelPullRequest(BaseModel):
    model: str = Field(..., description="Model name to pull (e.g., 'llama2:7b')")


class ChatResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None


class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None


# HTTP client for the Ollama API
async def get_ollama_client():
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "ollama_version": response.json(),
                    "timestamp": time.time()
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time()
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time()
        }


@app.get("/models")
async def list_models():
    """List available models"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
            response.raise_for_status()
            return response.json()
    except httpx.HTTPError as e:
        logger.error(f"Failed to list models: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}")
@app.post("/models/pull") async def pull_model(request: ModelPullRequest, background_tasks: BackgroundTasks): """Pull a model from Ollama registry""" try: async with await get_ollama_client() as client: # Start the pull request pull_data = {"name": request.model} response = await client.post( f"{OLLAMA_BASE_URL}/api/pull", json=pull_data, timeout=1800.0 # 30 minute timeout for model pulling ) if response.status_code == 200: return { "status": "success", "message": f"Successfully initiated pull for model '{request.model}'", "model": request.model } else: error_detail = response.text logger.error(f"Failed to pull model: {error_detail}") raise HTTPException( status_code=response.status_code, detail=f"Failed to pull model: {error_detail}" ) except httpx.TimeoutException: raise HTTPException( status_code=408, detail="Model pull request timed out. Large models may take longer to download." ) except Exception as e: logger.error(f"Error pulling model: {e}") raise HTTPException(status_code=500, detail=f"Error pulling model: {str(e)}") @app.delete("/models/{model_name}") async def delete_model(model_name: str): """Delete a model""" try: async with await get_ollama_client() as client: response = await client.delete(f"{OLLAMA_BASE_URL}/api/delete", json={"name": model_name}) response.raise_for_status() return {"status": "success", "message": f"Model '{model_name}' deleted successfully"} except httpx.HTTPError as e: logger.error(f"Failed to delete model: {e}") raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}") @app.post("/chat", response_model=ChatResponse) async def chat_with_model(request: ChatRequest): """Chat with a model""" try: # Convert messages to Ollama format chat_data = { "model": request.model, "messages": [{"role": msg.role, "content": msg.content} for msg in request.messages], "stream": request.stream, "options": { "temperature": request.temperature, "top_p": request.top_p, "num_predict": request.max_tokens } } async with await get_ollama_client() as client: response = await client.post( f"{OLLAMA_BASE_URL}/api/chat", json=chat_data, timeout=300.0 ) response.raise_for_status() result = response.json() return ChatResponse( model=result.get("model", request.model), response=result.get("message", {}).get("content", ""), done=result.get("done", True), total_duration=result.get("total_duration"), load_duration=result.get("load_duration"), prompt_eval_count=result.get("prompt_eval_count"), eval_count=result.get("eval_count") ) except httpx.HTTPError as e: logger.error(f"Chat request failed: {e}") if e.response.status_code == 404: raise HTTPException( status_code=404, detail=f"Model '{request.model}' not found. 
        raise HTTPException(status_code=500, detail=f"Chat request failed: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in chat: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")


@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion"""
    try:
        generate_data = {
            "model": request.model,
            "prompt": request.prompt,
            "stream": request.stream,
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens
            }
        }

        async with await get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data,
                timeout=300.0
            )
            response.raise_for_status()
            result = response.json()

            return GenerateResponse(
                model=result.get("model", request.model),
                response=result.get("response", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count")
            )
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError carries a .response; the broader HTTPError does not
        logger.error(f"Generate request failed: {e}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
            )
        raise HTTPException(status_code=500, detail=f"Generate request failed: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")


@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Ollama API Server",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "models": "/models",
            "pull_model": "/models/pull",
            "chat": "/chat",
            "generate": "/generate",
            "docs": "/docs"
        },
        "status": "running"
    }


if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Ollama API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
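
# Example client usage (a minimal sketch, run in a separate process while the
# server is up). It assumes the server listens on localhost:7860 as configured
# above, and that a model such as "llama2" has already been pulled via
# POST /models/pull; the model name is only an illustration.
#
#     import httpx
#
#     resp = httpx.post(
#         "http://localhost:7860/chat",
#         json={
#             "model": "llama2",
#             "messages": [{"role": "user", "content": "Hello!"}],
#         },
#         timeout=300.0,
#     )
#     resp.raise_for_status()
#     print(resp.json()["response"])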