from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
import httpx
import logging
import time
# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Ollama Generate API",
    description="Simple REST API for Ollama text generation",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"

# Pydantic models
class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")

class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None
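
# Illustrative sketch (not executed): pydantic enforces the Field constraints
# above before the endpoint body runs, so out-of-range input never reaches
# Ollama, e.g.:
#
#   GenerateRequest(model="tinyllama", prompt="hi", temperature=3.0)
#   # -> raises pydantic.ValidationError, since temperature violates le=2.0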

# HTTP client for Ollama API
def get_ollama_client():
    # Plain (non-async) factory: httpx.AsyncClient is constructed synchronously,
    # so call sites can simply do `async with get_ollama_client() as client:`.
    return httpx.AsyncClient(timeout=300.0)  # 5-minute timeout

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "timestamp": time.time()
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time()
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time()
        }
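
# Quick check (a sketch, assuming the app is running on port 7860 as
# configured at the bottom of this file):
#
#   curl http://localhost:7860/health
#   # -> {"status": "healthy", "ollama_status": "running", ...}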

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion using Ollama"""
    try:
        generate_data = {
            "model": request.model,
            "prompt": request.prompt,
            "stream": False,  # Always non-streaming for simplicity
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens
            }
        }
        logger.info(f"Generating text with model: {request.model}")
        async with get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data,
                timeout=300.0
            )
            if response.status_code == 404:
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{request.model}' not found. Make sure the model is pulled and available."
                )
            response.raise_for_status()
            result = response.json()
            return GenerateResponse(
                model=result.get("model", request.model),
                response=result.get("response", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count")
            )
    except HTTPException:
        # Re-raise the 404 above unchanged; the bare `except Exception` below
        # would otherwise swallow it and convert it into a 500.
        raise
    except httpx.TimeoutException:
        # Caught before httpx.HTTPStatusError, since TimeoutException is itself
        # a subclass of httpx.HTTPError and would be shadowed otherwise.
        logger.error("Generate request timed out")
        raise HTTPException(
            status_code=408,
            detail="Request timed out. Try with a shorter prompt or smaller max_tokens."
        )
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError carries a .response attribute; catching the
        # broader httpx.HTTPError here would risk an AttributeError.
        logger.error(f"Generate request failed: Status {e.response.status_code}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Make sure it's installed."
            )
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )
@app.get("/")
async def root():
"""Root endpoint with API information"""
return {
"message": "Ollama Generate API",
"version": "1.0.0",
"endpoints": {
"health": "/health - Check if Ollama is running",
"generate": "/generate - Generate text using Ollama models",
"docs": "/docs - API documentation"
},
"usage": {
"example": {
"url": "/generate",
"method": "POST",
"body": {
"model": "tinyllama",
"prompt": "Hello, how are you?",
"temperature": 0.7,
"max_tokens": 100
}
}
},
"status": "running"
}

if __name__ == "__main__":
    import uvicorn

    logger.info("Starting Ollama Generate API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
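
# Example client call (a sketch, assuming this server is running on
# localhost:7860 and that the "tinyllama" model from the /generate example
# above has already been pulled into Ollama):
#
#   import httpx
#
#   payload = {
#       "model": "tinyllama",
#       "prompt": "Hello, how are you?",
#       "temperature": 0.7,
#       "max_tokens": 100,
#   }
#   r = httpx.post("http://localhost:7860/generate", json=payload, timeout=300.0)
#   print(r.json()["response"])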