from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
import httpx
import logging
import time
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
# FastAPI app
app = FastAPI(
title="Ollama Generate API",
description="Simple REST API for Ollama text generation",
version="1.0.0",
docs_url="/docs",
redoc_url="/redoc"
)
# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"
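# If the Ollama server runs elsewhere (e.g., in a separate container), the URL
# could be read from the environment instead; a minimal sketch, assuming a
# conventional OLLAMA_BASE_URL variable (not wired in above):
#
#   import os
#   OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")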
# Pydantic models
class GenerateRequest(BaseModel):
model: str = Field(..., description="Model name to use for generation")
prompt: str = Field(..., description="Input prompt for text generation")
temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
class GenerateResponse(BaseModel):
model: str
response: str
done: bool
total_duration: Optional[int] = None
load_duration: Optional[int] = None
prompt_eval_count: Optional[int] = None
eval_count: Optional[int] = None
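# Note: the duration fields mirror Ollama's /api/generate response, which
# reports total_duration and load_duration in nanoseconds.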
# HTTP client for Ollama API
def get_ollama_client() -> httpx.AsyncClient:
    """Create an HTTP client for talking to the Ollama server."""
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout for slow generations
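# Note: a fresh AsyncClient is created per request, which is simple but adds
# connection overhead. Under heavier load, a single shared client could be
# opened at startup instead; a minimal sketch using FastAPI's lifespan hook
# (illustrative only, not wired into the app above):
#
#   from contextlib import asynccontextmanager
#
#   @asynccontextmanager
#   async def lifespan(app: FastAPI):
#       app.state.ollama = httpx.AsyncClient(timeout=300.0)
#       yield
#       await app.state.ollama.aclose()
#
#   # app = FastAPI(lifespan=lifespan)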
@app.get("/health")
async def health_check():
"""Health check endpoint"""
try:
        async with get_ollama_client() as client:
response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
if response.status_code == 200:
return {
"status": "healthy",
"ollama_status": "running",
"timestamp": time.time()
}
else:
return {
"status": "degraded",
"ollama_status": "error",
"error": f"Ollama returned status {response.status_code}",
"timestamp": time.time()
}
except Exception as e:
logger.error(f"Health check failed: {e}")
return {
"status": "unhealthy",
"ollama_status": "unreachable",
"error": str(e),
"timestamp": time.time()
}
@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
"""Generate text completion using Ollama"""
try:
generate_data = {
"model": request.model,
"prompt": request.prompt,
"stream": False, # Always non-streaming for simplicity
"options": {
"temperature": request.temperature,
"top_p": request.top_p,
"num_predict": request.max_tokens
}
}
logger.info(f"Generating text with model: {request.model}")
        async with get_ollama_client() as client:
            # The client already carries the 300 s timeout, so no per-request
            # override is needed here.
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data
            )
if response.status_code == 404:
raise HTTPException(
status_code=404,
detail=f"Model '{request.model}' not found. Make sure the model is pulled and available."
)
response.raise_for_status()
result = response.json()
return GenerateResponse(
model=result.get("model", request.model),
response=result.get("response", ""),
done=result.get("done", True),
total_duration=result.get("total_duration"),
load_duration=result.get("load_duration"),
prompt_eval_count=result.get("prompt_eval_count"),
eval_count=result.get("eval_count")
)
    except HTTPException:
        # Re-raise HTTPExceptions (such as the 404 above) unchanged; otherwise
        # the broad Exception handler below would turn them into 500s.
        raise
    except httpx.TimeoutException:
        # Must be caught before the broader httpx errors: TimeoutException is a
        # subclass of httpx.HTTPError, so handler order matters here.
        logger.error("Generate request timed out")
        raise HTTPException(
            status_code=408,
            detail="Request timed out. Try a shorter prompt or a smaller max_tokens."
        )
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError carries a .response; the base httpx.HTTPError does not.
        logger.error(f"Generate request failed: status {e.response.status_code}")
        raise HTTPException(
            status_code=502,
            detail=f"Ollama returned status {e.response.status_code}: {e.response.text}"
        )
    except httpx.RequestError as e:
        logger.error(f"Could not reach Ollama: {e}")
        raise HTTPException(
            status_code=503,
            detail=f"Ollama server unreachable: {e}"
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
@app.get("/")
async def root():
"""Root endpoint with API information"""
return {
"message": "Ollama Generate API",
"version": "1.0.0",
"endpoints": {
"health": "/health - Check if Ollama is running",
"generate": "/generate - Generate text using Ollama models",
"docs": "/docs - API documentation"
},
"usage": {
"example": {
"url": "/generate",
"method": "POST",
"body": {
"model": "tinyllama",
"prompt": "Hello, how are you?",
"temperature": 0.7,
"max_tokens": 100
}
}
},
"status": "running"
}
if __name__ == "__main__":
import uvicorn
logger.info("Starting Ollama Generate API server...")
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
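# Example usage once the server is running (assumes a model such as
# "tinyllama" has already been pulled into Ollama):
#
#   import httpx
#   r = httpx.post(
#       "http://localhost:7860/generate",
#       json={"model": "tinyllama", "prompt": "Hello, how are you?", "max_tokens": 100},
#       timeout=300.0,
#   )
#   print(r.json()["response"])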