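"""FastAPI wrapper around a local Ollama server.

Exposes /health, /models, /models/pull, /models/{model_name} (DELETE), /chat and
/generate endpoints that proxy to the Ollama HTTP API, assumed to be listening on
http://localhost:11434. Running this file directly serves the API on port 7860.
"""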
from fastapi import FastAPI, HTTPException, BackgroundTasks
from pydantic import BaseModel, Field
from typing import List, Optional, Dict, Any
import httpx
import logging
import time

# Configure logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# FastAPI app
app = FastAPI(
    title="Ollama API Server",
    description="REST API for running Ollama models",
    version="1.0.0",
    docs_url="/docs",
    redoc_url="/redoc"
)

# Ollama server configuration
OLLAMA_BASE_URL = "http://localhost:11434"

# Pydantic models
class ChatMessage(BaseModel):
    role: str = Field(..., description="Role of the message sender (user, assistant, system)")
    content: str = Field(..., description="Content of the message")

class ChatRequest(BaseModel):
    model: str = Field(..., description="Model name to use for chat")
    messages: List[ChatMessage] = Field(..., description="List of chat messages")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")

class GenerateRequest(BaseModel):
    model: str = Field(..., description="Model name to use for generation")
    prompt: str = Field(..., description="Input prompt for text generation")
    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
    stream: Optional[bool] = Field(False, description="Whether to stream the response")

class ModelPullRequest(BaseModel):
    model: str = Field(..., description="Model name to pull (e.g., 'llama2:7b')")

class ChatResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None

class GenerateResponse(BaseModel):
    model: str
    response: str
    done: bool
    total_duration: Optional[int] = None
    load_duration: Optional[int] = None
    prompt_eval_count: Optional[int] = None
    eval_count: Optional[int] = None

# HTTP client for Ollama API
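# Note: a fresh AsyncClient is created per request for simplicity; a shared,
# long-lived client (e.g. created on startup) would avoid reconnecting on every call.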
async def get_ollama_client():
    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout

@app.get("/health")
async def health_check():
    """Health check endpoint"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
            if response.status_code == 200:
                return {
                    "status": "healthy",
                    "ollama_status": "running",
                    "ollama_version": response.json(),
                    "timestamp": time.time()
                }
            else:
                return {
                    "status": "degraded",
                    "ollama_status": "error",
                    "error": f"Ollama returned status {response.status_code}",
                    "timestamp": time.time()
                }
    except Exception as e:
        logger.error(f"Health check failed: {e}")
        return {
            "status": "unhealthy",
            "ollama_status": "unreachable",
            "error": str(e),
            "timestamp": time.time()
        }

@app.get("/models")
async def list_models():
    """List available models"""
    try:
        async with await get_ollama_client() as client:
            response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
            response.raise_for_status()
            return response.json()
    except httpx.HTTPError as e:
        logger.error(f"Failed to list models: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}")

@app.post("/models/pull")
async def pull_model(request: ModelPullRequest, background_tasks: BackgroundTasks):
    """Pull a model from Ollama registry"""
    try:
        async with await get_ollama_client() as client:
            # Start the pull request
            pull_data = {"name": request.model}
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/pull",
                json=pull_data,
                timeout=1800.0  # 30 minute timeout for model pulling
            )
            
            if response.status_code == 200:
                return {
                    "status": "success",
                    "message": f"Successfully initiated pull for model '{request.model}'",
                    "model": request.model
                }
            else:
                error_detail = response.text
                logger.error(f"Failed to pull model: {error_detail}")
                raise HTTPException(
                    status_code=response.status_code,
                    detail=f"Failed to pull model: {error_detail}"
                )
    except httpx.TimeoutException:
        raise HTTPException(
            status_code=408,
            detail="Model pull request timed out. Large models may take longer to download."
        )
    except Exception as e:
        logger.error(f"Error pulling model: {e}")
        raise HTTPException(status_code=500, detail=f"Error pulling model: {str(e)}")

@app.delete("/models/{model_name}")
async def delete_model(model_name: str):
    """Delete a model"""
    try:
        async with await get_ollama_client() as client:
            # httpx's .delete() helper does not accept a request body, so use .request() to send the JSON payload
            response = await client.request(
                "DELETE", f"{OLLAMA_BASE_URL}/api/delete", json={"name": model_name}
            )
            response.raise_for_status()
            return {"status": "success", "message": f"Model '{model_name}' deleted successfully"}
    except httpx.HTTPError as e:
        logger.error(f"Failed to delete model: {e}")
        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")

@app.post("/chat", response_model=ChatResponse)
async def chat_with_model(request: ChatRequest):
    """Chat with a model"""
    try:
        # Convert messages to Ollama format
        chat_data = {
            "model": request.model,
            "messages": [{"role": msg.role, "content": msg.content} for msg in request.messages],
            "stream": request.stream,
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens
            }
        }
        
        async with await get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/chat",
                json=chat_data,
                timeout=300.0
            )
            response.raise_for_status()
            result = response.json()
            
            return ChatResponse(
                model=result.get("model", request.model),
                response=result.get("message", {}).get("content", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count")
            )
            
    except httpx.HTTPStatusError as e:
        # Only HTTPStatusError carries a .response; other httpx errors (e.g. connection
        # failures) fall through to the generic handler below.
        logger.error(f"Chat request failed: {e}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
            )
        raise HTTPException(status_code=500, detail=f"Chat request failed: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in chat: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

@app.post("/generate", response_model=GenerateResponse)
async def generate_text(request: GenerateRequest):
    """Generate text completion"""
    try:
        generate_data = {
            "model": request.model,
            "prompt": request.prompt,
            "stream": request.stream,
            "options": {
                "temperature": request.temperature,
                "top_p": request.top_p,
                "num_predict": request.max_tokens
            }
        }
        
        async with await get_ollama_client() as client:
            response = await client.post(
                f"{OLLAMA_BASE_URL}/api/generate",
                json=generate_data,
                timeout=300.0
            )
            response.raise_for_status()
            result = response.json()
            
            return GenerateResponse(
                model=result.get("model", request.model),
                response=result.get("response", ""),
                done=result.get("done", True),
                total_duration=result.get("total_duration"),
                load_duration=result.get("load_duration"),
                prompt_eval_count=result.get("prompt_eval_count"),
                eval_count=result.get("eval_count")
            )
            
    except httpx.HTTPStatusError as e:
        # As above, only HTTPStatusError exposes .response; other errors hit the generic handler.
        logger.error(f"Generate request failed: {e}")
        if e.response.status_code == 404:
            raise HTTPException(
                status_code=404,
                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
            )
        raise HTTPException(status_code=500, detail=f"Generate request failed: {str(e)}")
    except Exception as e:
        logger.error(f"Unexpected error in generate: {e}")
        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")

@app.get("/")
async def root():
    """Root endpoint with API information"""
    return {
        "message": "Ollama API Server",
        "version": "1.0.0",
        "endpoints": {
            "health": "/health",
            "models": "/models",
            "pull_model": "/models/pull",
            "chat": "/chat",
            "generate": "/generate",
            "docs": "/docs"
        },
        "status": "running"
    }

if __name__ == "__main__":
    import uvicorn
    logger.info("Starting Ollama API server...")
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")