from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
import httpx
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Ollama Generate API",
    description="Simple REST API for Ollama text generation",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"


# Pydantic models
class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")


class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None


# HTTP client for the Ollama API
async def get_ollama_client():
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout


@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "timestamp": time.time(),
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time(),
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time(),
        }


@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion using Ollama"""
    try:
        generate_data = {
            "model": request.model,
            "prompt": request.prompt,
            "stream": False,  # Always non-streaming for simplicity
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens,
            },
        }

        logger.info(f"Generating text with model: {request.model}")

        async with await get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data,
                timeout=300.0,
            )

            if response.status_code == 404:
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{request.model}' not found. Make sure the model is pulled and available.",
                )

            response.raise_for_status()
            result = response.json()

            return GenerateResponse(
                model=result.get("model", request.model),
                response=result.get("response", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count"),
            )

    except HTTPException:
        # Re-raise explicit HTTP errors (e.g. the 404 above) instead of
        # letting the generic handler turn them into a 500
        raise
    except httpx.TimeoutException:
        # Must be caught before broader httpx errors, since it is a subclass of them
        logger.error("Generate request timed out")
        raise HTTPException(
            status_code=408,
            detail="Request timed out. Try a shorter prompt or a smaller max_tokens.",
        )
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError carries a .response with a status code
        logger.error(f"Generate request failed: status {e.response.status_code}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Make sure it's installed.",
            )
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}",
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}",
        )


@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Ollama Generate API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health - Check if Ollama is running",
            "generate": "/generate - Generate text using Ollama models",
            "docs": "/docs - API documentation",
        },
        "usage": {
            "example": {
                "url": "/generate",
                "method": "POST",
                "body": {
                    "model": "tinyllama",
                    "prompt": "Hello, how are you?",
                    "temperature": 0.7,
                    "max_tokens": 100,
                },
            }
        },
        "status": "running",
    }


if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Ollama Generate API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
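
# Example client call once the server is running on port 7860 (a minimal sketch;
# it assumes a model such as "tinyllama" has already been pulled into the local
# Ollama instance):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:7860/generate",
#       json={"model": "tinyllama", "prompt": "Hello, how are you?",
#             "temperature": 0.7, "max_tokens": 100},
#       timeout=300.0,
#   )
#   print(r.json()["response"])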