from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
import httpx
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Ollama Generate API",
    description="Simple REST API for Ollama text generation",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Ollama server configuration (11434 is Ollama's default port)
OLLAMA_BASE_URL = "http://localhost:11434"

# Pydantic models
class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")

class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None

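# Illustrative only: a request body that passes GenerateRequest validation;
# unset optional fields fall back to their declared defaults (0.7 / 0.9 / 512).
#
#   GenerateRequest(model="tinyllama", prompt="Hello, how are you?",
#                   temperature=0.7, max_tokens=100)
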
# HTTP client for Ollama API
async def get_ollama_client():
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout

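# Note: get_ollama_client() is a coroutine returning a fresh AsyncClient, so the
# handlers below use `async with await get_ollama_client() as client:`. A single
# shared client created at startup would also work; one client per request is
# kept here for simplicity.
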
@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "timestamp": time.time()
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time()
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time()
        }

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion using Ollama"""
    try:
        generate_data = {
            "model": request.model,
            "prompt": request.prompt,
            "stream": False,  # Always non-streaming for simplicity
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens
            }
        }
        logger.info(f"Generating text with model: {request.model}")
        async with await get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data,
                timeout=300.0
            )
            if response.status_code == 404:
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{request.model}' not found. Make sure the model is pulled and available."
                )
            response.raise_for_status()
            result = response.json()
            return GenerateResponse(
                model=result.get("model", request.model),
                response=result.get("response", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count")
            )
    except HTTPException:
        # Re-raise FastAPI errors (e.g. the 404 above) instead of wrapping them as 500s
        raise
    except httpx.TimeoutException:
        # Must be caught before httpx.HTTPError, which it subclasses
        logger.error("Generate request timed out")
        raise HTTPException(
            status_code=408,
            detail="Request timed out. Try with a shorter prompt or smaller max_tokens."
        )
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError carries .response; the 404 case is already handled above
        logger.error(f"Generate request failed: status {e.response.status_code}")
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        )
    except httpx.HTTPError as e:
        # Remaining transport-level errors (connection failures, etc.)
        logger.error(f"Generate request failed: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )

@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Ollama Generate API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health - Check if Ollama is running",
            "generate": "/generate - Generate text using Ollama models",
            "docs": "/docs - API documentation"
        },
        "usage": {
            "example": {
                "url": "/generate",
                "method": "POST",
                "body": {
                    "model": "tinyllama",
                    "prompt": "Hello, how are you?",
                    "temperature": 0.7,
                    "max_tokens": 100
                }
            }
        },
        "status": "running"
    }

if __name__ == "__main__":
    import uvicorn
    logger.info("Starting Ollama Generate API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
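
# Example usage (illustrative), once the server is running on port 7860:
#
#   curl http://localhost:7860/health
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "tinyllama", "prompt": "Hello, how are you?", "max_tokens": 100}'
#
# This assumes the "tinyllama" model has already been pulled into the local
# Ollama instance (e.g. `ollama pull tinyllama`); substitute any pulled model.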