from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field
from typing import Optional
import httpx
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Ollama Generate API",
    description="Simple REST API for Ollama text generation",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"
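# Note: 11434 is Ollama's default local API port; adjust OLLAMA_BASE_URL if your
# Ollama server listens on a different host or port.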

# Pydantic models
class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")

class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None

# HTTP client for Ollama API
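# Note: a fresh AsyncClient is created per request to keep the example simple; under
# heavier traffic a single shared client (e.g. opened in a FastAPI lifespan handler)
# would reuse connections instead of re-opening them on every call.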
async def get_ollama_client():
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "timestamp": time.time()
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time()
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time()
        }

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion using Ollama"""
    try:
        generate_data = {
            "model": request.model,
            "prompt": request.prompt,
            "stream": False,  # Always non-streaming for simplicity
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens
            }
        }
        
        logger.info(f"Generating text with model: {request.model}")
        
        async with await get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data,
                timeout=300.0
            )
            
            if response.status_code == 404:
                raise HTTPException(
                    status_code=404,
                    detail=f"Model '{request.model}' not found. Make sure the model is pulled and available."
                )
            
            response.raise_for_status()
            result = response.json()
            
            return GenerateResponse(
                model=result.get("model", request.model),
                response=result.get("response", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count")
            )
            
    except HTTPException:
        # Re-raise explicit HTTPExceptions (e.g. the 404 above) instead of wrapping them as 500s
        raise
    except httpx.TimeoutException:
        logger.error("Generate request timed out")
        raise HTTPException(
            status_code=408,
            detail="Request timed out. Try with a shorter prompt or smaller max_tokens."
        )
    except httpx.HTTPStatusError as e:
        logger.error(f"Generate request failed: Status {e.response.status_code}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Make sure it's installed."
            )
        raise HTTPException(
            status_code=500,
            detail=f"Generation failed: {str(e)}"
        )
    except httpx.HTTPError as e:
        # Transport-level errors (connection refused, DNS failure, etc.) carry no response object
        logger.error(f"Generate request failed: {e}")
        raise HTTPException(
            status_code=502,
            detail=f"Could not reach Ollama: {str(e)}"
        )
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(
            status_code=500,
            detail=f"Unexpected error: {str(e)}"
        )

@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Ollama Generate API",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health - Check if Ollama is running",
            "generate": "/generate - Generate text using Ollama models",
            "docs": "/docs - API documentation"
        },
        "usage": {
            "example": {
                "url": "/generate",
                "method": "POST",
                "body": {
                    "model": "tinyllama",
                    "prompt": "Hello, how are you?",
                    "temperature": 0.7,
                    "max_tokens": 100
                }
            }
        },
        "status": "running"
    }

if __name__ == "__main__":
    import uvicorn
    logger.info("Starting Ollama Generate API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
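
# A minimal usage sketch, assuming the server is running locally on port 7860 and a
# model named "tinyllama" has already been pulled into Ollama:
#
#   curl http://localhost:7860/health
#
#   curl -X POST http://localhost:7860/generate \
#        -H "Content-Type: application/json" \
#        -d '{"model": "tinyllama", "prompt": "Hello, how are you?", "max_tokens": 100}'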