from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import List, Optional
import httpx
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Ollama API Server",
    description="REST API for running Ollama models",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc",
)

# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"
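# Optional tweak (an assumption, not part of the original code): allow the Ollama
# address to be overridden via an OLLAMA_BASE_URL environment variable, which is
# convenient when Ollama runs in another container or on another host.
import os
OLLAMA_BASE_URL = os.environ.get("OLLAMA_BASE_URL", OLLAMA_BASE_URL)
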
# Pydantic models
class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (user, assistant, system)")
    content: str = Field(..., description="Content of the message")


class ChatRequest(BaseModel):
    model: str = Field(..., description="Model name to use for chat")
    messages: List[ChatMessage] = Field(..., description="List of chat messages")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")


class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")


class ModelPullRequest(BaseModel):
    model: str = Field(..., description="Model name to pull (e.g., 'llama2:7b')")


class ChatResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None


class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None

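# Example JSON body for the chat endpoint, following the ChatRequest schema above
# (the model name is purely illustrative and must already be available locally):
#
# {
#   "model": "llama2:7b",
#   "messages": [
#     {"role": "system", "content": "You are a helpful assistant."},
#     {"role": "user", "content": "Why is the sky blue?"}
#   ],
#   "temperature": 0.7,
#   "max_tokens": 256
# }
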
# HTTP client for Ollama API
async def get_ollama_client():
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "ollama_version": response.json(),
                    "timestamp": time.time(),
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time(),
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time(),
        }

@app.get("/models")
async def list_models():
    """List available models"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
            response.raise_for_status()
            return response.json()
    except httpx.HTTPError as e:
        logger.error(f"Failed to list models: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}")

@app.post("/models/pull")
async def pull_model(request: ModelPullRequest, background_tasks: BackgroundTasks):
    """Pull a model from the Ollama registry"""
    # Note: background_tasks is accepted but unused; the pull runs inline and may take a while.
    try:
        async with await get_ollama_client() as client:
            # Start the pull request
            pull_data = {"name": request.model}
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/pull",
                json=pull_data,
                timeout=1800.0,  # 30 minute timeout for model pulling
            )
            if response.status_code == 200:
                return {
                    "status": "success",
                    "message": f"Successfully initiated pull for model '{request.model}'",
                    "model": request.model,
                }
            else:
                error_detail = response.text
                logger.error(f"Failed to pull model: {error_detail}")
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"Failed to pull model: {error_detail}",
                )
    except httpx.TimeoutException:
        raise HTTPException(
            status_code=408,
            detail="Model pull request timed out. Large models may take longer to download.",
        )
    except HTTPException:
        # Re-raise HTTP errors as-is instead of wrapping them in a generic 500 below
        raise
    except Exception as e:
        logger.error(f"Error pulling model: {e}")
        raise HTTPException(status_code=500, detail=f"Error pulling model: {str(e)}")

@app.delete("/models/{model_name}")  # route path is an assumption; it is not named in the original
async def delete_model(model_name: str):
    """Delete a model"""
    try:
        async with await get_ollama_client() as client:
            # httpx's delete() does not accept a JSON body, so build the DELETE request explicitly
            response = await client.request(
                "DELETE", f"{OLLAMA_BASE_URL}/api/delete", json={"name": model_name}
            )
            response.raise_for_status()
            return {"status": "success", "message": f"Model '{model_name}' deleted successfully"}
    except httpx.HTTPError as e:
        logger.error(f"Failed to delete model: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")

@app.post("/chat", response_model=ChatResponse)
async def chat_with_model(request: ChatRequest):
    """Chat with a model"""
    try:
        # Convert messages to Ollama format
        chat_data = {
            "model": request.model,
            "messages": [{"role": msg.role, "content": msg.content} for msg in request.messages],
            "stream": request.stream,  # streaming responses are not parsed below, so this should stay False
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens,
            },
        }
        async with await get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/chat",
                json=chat_data,
                timeout=300.0,
            )
            response.raise_for_status()
            result = response.json()
            return ChatResponse(
                model=result.get("model", request.model),
                response=result.get("message", {}).get("content", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count"),
            )
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError carries a response; plain HTTPError (e.g. connection failures) does not
        logger.error(f"Chat request failed: {e}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull",
            )
        raise HTTPException(status_code=500, detail=f"Chat request failed: {str(e)}")
    except httpx.HTTPError as e:
        logger.error(f"Chat request failed: {e}")
        raise HTTPException(status_code=500, detail=f"Chat request failed: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in chat: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion"""
    try:
        generate_data = {
            "model": request.model,
            "prompt": request.prompt,
            "stream": request.stream,  # streaming responses are not parsed below, so this should stay False
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens,
            },
        }
        async with await get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data,
                timeout=300.0,
            )
            response.raise_for_status()
            result = response.json()
            return GenerateResponse(
                model=result.get("model", request.model),
                response=result.get("response", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count"),
            )
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError carries a response; plain HTTPError (e.g. connection failures) does not
        logger.error(f"Generate request failed: {e}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull",
            )
        raise HTTPException(status_code=500, detail=f"Generate request failed: {str(e)}")
    except httpx.HTTPError as e:
        logger.error(f"Generate request failed: {e}")
        raise HTTPException(status_code=500, detail=f"Generate request failed: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Ollama API Server",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "models": "/models",
            "pull_model": "/models/pull",
            "chat": "/chat",
            "generate": "/generate",
            "docs": "/docs",
        },
        "status": "running",
    }

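# Illustrative client sketch (an addition, never invoked by the server): shows how
# this API could be called once it is running on http://localhost:7860. The model
# name "llama2:7b" is only an example and must have been pulled beforehand.
async def _example_generate_call():
    async with httpx.AsyncClient(timeout=300.0) as client:
        resp = await client.post(
            "http://localhost:7860/generate",
            json={"model": "llama2:7b", "prompt": "Why is the sky blue?", "max_tokens": 128},
        )
        resp.raise_for_status()
        print(resp.json()["response"])
        # Run with: import asyncio; asyncio.run(_example_generate_call())
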
if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Ollama API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")