from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import httpx
import asyncio
import logging
import time
import json
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app
app = FastAPI(
    title="Ollama API Server",
    description="REST API for running Ollama models",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)
# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"
# Pydantic models
class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (user, assistant, system)")
    content: str = Field(..., description="Content of the message")

class ChatRequest(BaseModel):
    model: str = Field(..., description="Model name to use for chat")
    messages: List[ChatMessage] = Field(..., description="List of chat messages")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")

class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")

class ModelPullRequest(BaseModel):
    model: str = Field(..., description="Model name to pull (e.g., 'llama2:7b')")

class ChatResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None

class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None
# HTTP client for Ollama API
async def get_ollama_client():
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout
@app.get("/health")
async def health_check():
"""Health check endpoint"""
try:
async with await get_ollama_client() as client:
response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
if response.status_code == 200:
return {
"status": "healthy",
"ollama_status": "running",
"ollama_version": response.json(),
"timestamp": time.time()
}
else:
return {
"status": "degraded",
"ollama_status": "error",
"error": f"Ollama returned status {response.status_code}",
"timestamp": time.time()
}
except Exception as e:
logger.error(f"Health check failed: {e}")
return {
"status": "unhealthy",
"ollama_status": "unreachable",
"error": str(e),
"timestamp": time.time()
}
@app.get("/models")
async def list_models():
"""List available models"""
try:
async with await get_ollama_client() as client:
response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
response.raise_for_status()
return response.json()
except httpx.HTTPError as e:
logger.error(f"Failed to list models: {e}")
raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}")
@app.post("/models/pull")
async def pull_model(request: ModelPullRequest, background_tasks: BackgroundTasks):
"""Pull a model from Ollama registry"""
try:
async with await get_ollama_client() as client:
# Start the pull request
pull_data = {"name": request.model}
response = await client.post(
f"{OLLAMA_BASE_URL}/api/pull",
json=pull_data,
timeout=1800.0 # 30 minute timeout for model pulling
)
if response.status_code == 200:
return {
"status": "success",
"message": f"Successfully initiated pull for model '{request.model}'",
"model": request.model
}
else:
error_detail = response.text
logger.error(f"Failed to pull model: {error_detail}")
raise HTTPException(
status_code=response.status_code,
detail=f"Failed to pull model: {error_detail}"
)
except httpx.TimeoutException:
raise HTTPException(
status_code=408,
detail="Model pull request timed out. Large models may take longer to download."
)
except Exception as e:
logger.error(f"Error pulling model: {e}")
raise HTTPException(status_code=500, detail=f"Error pulling model: {str(e)}")
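# Illustrative usage sketch (comments only, not executed; assumes the server is
# running on localhost:7860 as configured in the __main__ block below, and
# 'llama2:7b' is just an example model name):
#
#   import httpx
#   httpx.post("http://localhost:7860/models/pull",
#              json={"model": "llama2:7b"}, timeout=1800.0)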
@app.delete("/models/{model_name}")
async def delete_model(model_name: str):
"""Delete a model"""
try:
async with await get_ollama_client() as client:
response = await client.delete(f"{OLLAMA_BASE_URL}/api/delete", json={"name": model_name})
response.raise_for_status()
return {"status": "success", "message": f"Model '{model_name}' deleted successfully"}
except httpx.HTTPError as e:
logger.error(f"Failed to delete model: {e}")
raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
@app.post("/chat", response_model=ChatResponse)
async def chat_with_model(request: ChatRequest):
"""Chat with a model"""
try:
# Convert messages to Ollama format
chat_data = {
"model": request.model,
"messages": [{"role": msg.role, "content": msg.content} for msg in request.messages],
"stream": request.stream,
"options": {
"temperature": request.temperature,
"top_p": request.top_p,
"num_predict": request.max_tokens
}
}
async with await get_ollama_client() as client:
response = await client.post(
f"{OLLAMA_BASE_URL}/api/chat",
json=chat_data,
timeout=300.0
)
response.raise_for_status()
result = response.json()
return ChatResponse(
model=result.get("model", request.model),
response=result.get("message", {}).get("content", ""),
done=result.get("done", True),
total_duration=result.get("total_duration"),
load_duration=result.get("load_duration"),
prompt_eval_count=result.get("prompt_eval_count"),
eval_count=result.get("eval_count")
)
except httpx.HTTPError as e:
logger.error(f"Chat request failed: {e}")
if e.response.status_code == 404:
raise HTTPException(
status_code=404,
detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
)
raise HTTPException(status_code=500, detail=f"Chat request failed: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error in chat: {e}")
raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
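# Illustrative usage sketch (comments only, not executed; host/port follow the
# uvicorn settings in the __main__ block, and the model name is only an example):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:7860/chat",
#       json={
#           "model": "llama2:7b",
#           "messages": [{"role": "user", "content": "Hello!"}],
#       },
#       timeout=300.0,
#   )
#   print(r.json()["response"])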
@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
"""Generate text completion"""
try:
generate_data = {
"model": request.model,
"prompt": request.prompt,
"stream": request.stream,
"options": {
"temperature": request.temperature,
"top_p": request.top_p,
"num_predict": request.max_tokens
}
}
async with await get_ollama_client() as client:
response = await client.post(
f"{OLLAMA_BASE_URL}/api/generate",
json=generate_data,
timeout=300.0
)
response.raise_for_status()
result = response.json()
return GenerateResponse(
model=result.get("model", request.model),
response=result.get("response", ""),
done=result.get("done", True),
total_duration=result.get("total_duration"),
load_duration=result.get("load_duration"),
prompt_eval_count=result.get("prompt_eval_count"),
eval_count=result.get("eval_count")
)
except httpx.HTTPError as e:
logger.error(f"Generate request failed: {e}")
if e.response.status_code == 404:
raise HTTPException(
status_code=404,
detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
)
raise HTTPException(status_code=500, detail=f"Generate request failed: {str(e)}")
except Exception as e:
logger.error(f"Unexpected error in generate: {e}")
raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
@app.get("/")
async def root():
"""Root endpoint with API information"""
return {
"message": "Ollama API Server",
"version": "1.0.0",
"endpoints": {
"health": "/health",
"models": "/models",
"pull_model": "/models/pull",
"chat": "/chat",
"generate": "/generate",
"docs": "/docs"
},
"status": "running"
}
if __name__ == "__main__":
import uvicorn
logger.info("Starting Ollama API server...")
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
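# Example client session (a sketch, not part of the server; assumes the server was
# started with this file and that an appropriate model has already been pulled):
#
#   import httpx
#   base = "http://localhost:7860"
#   print(httpx.get(f"{base}/health").json())
#   print(httpx.get(f"{base}/models").json())
#   out = httpx.post(f"{base}/generate",
#                    json={"model": "llama2:7b", "prompt": "Why is the sky blue?"},
#                    timeout=300.0)
#   print(out.json()["response"])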