brendon-ai committed (verified)
Commit c0fd7e0 · Parent: 4974b02

Update app.py

Files changed (1)
  1. app.py  +48 -152
app.py CHANGED
@@ -1,11 +1,9 @@
-from fastapi import FastAPI, HTTPException, BackgroundTasks
+from fastapi import FastAPI, HTTPException
 from pydantic import BaseModel, Field
-from typing import List, Optional, Dict, Any
+from typing import Optional
 import httpx
-import asyncio
 import logging
 import time
-import json
 
 # Configure logging
 logging.basicConfig(level=logging.INFO)
@@ -13,8 +11,8 @@ logger = logging.getLogger(__name__)
 
 # FastAPI app
 app = FastAPI(
-    title="Ollama API Server",
-    description="REST API for running Ollama models",
+    title="Ollama Generate API",
+    description="Simple REST API for Ollama text generation",
     version="1.0.0",
     docs_url="/docs",
     redoc_url="/redoc"
@@ -24,37 +22,12 @@ app = FastAPI(
 OLLAMA_BASE_URL = "http://localhost:11434"
 
 # Pydantic models
-class ChatMessage(BaseModel):
-    role: str = Field(..., description="Role of the message sender (user, assistant, system)")
-    content: str = Field(..., description="Content of the message")
-
-class ChatRequest(BaseModel):
-    model: str = Field(..., description="Model name to use for chat")
-    messages: List[ChatMessage] = Field(..., description="List of chat messages")
-    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
-    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
-    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
-    stream: Optional[bool] = Field(False, description="Whether to stream the response")
-
 class GenerateRequest(BaseModel):
     model: str = Field(..., description="Model name to use for generation")
     prompt: str = Field(..., description="Input prompt for text generation")
     temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
     top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
     max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
-    stream: Optional[bool] = Field(False, description="Whether to stream the response")
-
-class ModelPullRequest(BaseModel):
-    model: str = Field(..., description="Model name to pull (e.g., 'llama2:7b')")
-
-class ChatResponse(BaseModel):
-    model: str
-    response: str
-    done: bool
-    total_duration: Optional[int] = None
-    load_duration: Optional[int] = None
-    prompt_eval_count: Optional[int] = None
-    eval_count: Optional[int] = None
 
 class GenerateResponse(BaseModel):
     model: str
@@ -79,7 +52,6 @@ async def health_check():
                 return {
                     "status": "healthy",
                     "ollama_status": "running",
-                    "ollama_version": response.json(),
                     "timestamp": time.time()
                 }
             else:
@@ -98,120 +70,14 @@ async def health_check():
             "timestamp": time.time()
         }
 
-@app.get("/models")
-async def list_models():
-    """List available models"""
-    try:
-        async with await get_ollama_client() as client:
-            response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
-            response.raise_for_status()
-            return response.json()
-    except httpx.HTTPError as e:
-        logger.error(f"Failed to list models: {e}")
-        raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}")
-
-@app.post("/models/pull")
-async def pull_model(request: ModelPullRequest, background_tasks: BackgroundTasks):
-    """Pull a model from Ollama registry"""
-    try:
-        async with await get_ollama_client() as client:
-            # Start the pull request
-            pull_data = {"name": request.model}
-            response = await client.post(
-                f"{OLLAMA_BASE_URL}/api/pull",
-                json=pull_data,
-                timeout=1800.0  # 30 minute timeout for model pulling
-            )
-
-            if response.status_code == 200:
-                return {
-                    "status": "success",
-                    "message": f"Successfully initiated pull for model '{request.model}'",
-                    "model": request.model
-                }
-            else:
-                error_detail = response.text
-                logger.error(f"Failed to pull model: {error_detail}")
-                raise HTTPException(
-                    status_code=response.status_code,
-                    detail=f"Failed to pull model: {error_detail}"
-                )
-    except httpx.TimeoutException:
-        raise HTTPException(
-            status_code=408,
-            detail="Model pull request timed out. Large models may take longer to download."
-        )
-    except Exception as e:
-        logger.error(f"Error pulling model: {e}")
-        raise HTTPException(status_code=500, detail=f"Error pulling model: {str(e)}")
-
-@app.delete("/models/{model_name}")
-async def delete_model(model_name: str):
-    """Delete a model"""
-    try:
-        async with await get_ollama_client() as client:
-            response = await client.delete(f"{OLLAMA_BASE_URL}/api/delete", json={"name": model_name})
-            response.raise_for_status()
-            return {"status": "success", "message": f"Model '{model_name}' deleted successfully"}
-    except httpx.HTTPError as e:
-        logger.error(f"Failed to delete model: {e}")
-        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
-
-@app.post("/chat", response_model=ChatResponse)
-async def chat_with_model(request: ChatRequest):
-    """Chat with a model"""
-    try:
-        # Convert messages to Ollama format
-        chat_data = {
-            "model": request.model,
-            "messages": [{"role": msg.role, "content": msg.content} for msg in request.messages],
-            "stream": request.stream,
-            "options": {
-                "temperature": request.temperature,
-                "top_p": request.top_p,
-                "num_predict": request.max_tokens
-            }
-        }
-
-        async with await get_ollama_client() as client:
-            response = await client.post(
-                f"{OLLAMA_BASE_URL}/api/chat",
-                json=chat_data,
-                timeout=300.0
-            )
-            response.raise_for_status()
-            result = response.json()
-
-        return ChatResponse(
-            model=result.get("model", request.model),
-            response=result.get("message", {}).get("content", ""),
-            done=result.get("done", True),
-            total_duration=result.get("total_duration"),
-            load_duration=result.get("load_duration"),
-            prompt_eval_count=result.get("prompt_eval_count"),
-            eval_count=result.get("eval_count")
-        )
-
-    except httpx.HTTPError as e:
-        logger.error(f"Chat request failed: {e}")
-        if e.response.status_code == 404:
-            raise HTTPException(
-                status_code=404,
-                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
-            )
-        raise HTTPException(status_code=500, detail=f"Chat request failed: {str(e)}")
-    except Exception as e:
-        logger.error(f"Unexpected error in chat: {e}")
-        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
-
 @app.post("/generate", response_model=GenerateResponse)
 async def generate_text(request: GenerateRequest):
-    """Generate text completion"""
+    """Generate text completion using Ollama"""
     try:
         generate_data = {
             "model": request.model,
             "prompt": request.prompt,
-            "stream": request.stream,
+            "stream": False,  # Always non-streaming for simplicity
             "options": {
                 "temperature": request.temperature,
                 "top_p": request.top_p,
@@ -219,12 +85,21 @@ async def generate_text(request: GenerateRequest):
             }
         }
 
+        logger.info(f"Generating text with model: {request.model}")
+
         async with await get_ollama_client() as client:
            response = await client.post(
                 f"{OLLAMA_BASE_URL}/api/generate",
                 json=generate_data,
                 timeout=300.0
            )
+
+            if response.status_code == 404:
+                raise HTTPException(
+                    status_code=404,
+                    detail=f"Model '{request.model}' not found. Make sure the model is pulled and available."
+                )
+
            response.raise_for_status()
            result = response.json()
 
@@ -239,35 +114,56 @@ async def generate_text(request: GenerateRequest):
         )
 
     except httpx.HTTPError as e:
-        logger.error(f"Generate request failed: {e}")
+        logger.error(f"Generate request failed: Status {e.response.status_code}")
         if e.response.status_code == 404:
             raise HTTPException(
                 status_code=404,
-                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
+                detail=f"Model '{request.model}' not found. Make sure it's installed."
             )
-        raise HTTPException(status_code=500, detail=f"Generate request failed: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Generation failed: {str(e)}"
+        )
+    except httpx.TimeoutException:
+        logger.error("Generate request timed out")
+        raise HTTPException(
+            status_code=408,
+            detail="Request timed out. Try with a shorter prompt or smaller max_tokens."
+        )
     except Exception as e:
         logger.error(f"Unexpected error in generate: {e}")
-        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
+        raise HTTPException(
+            status_code=500,
+            detail=f"Unexpected error: {str(e)}"
+        )
 
 @app.get("/")
 async def root():
     """Root endpoint with API information"""
     return {
-        "message": "Ollama API Server",
+        "message": "Ollama Generate API",
         "version": "1.0.0",
         "endpoints": {
-            "health": "/health",
-            "models": "/models",
-            "pull_model": "/models/pull",
-            "chat": "/chat",
-            "generate": "/generate",
-            "docs": "/docs"
+            "health": "/health - Check if Ollama is running",
+            "generate": "/generate - Generate text using Ollama models",
+            "docs": "/docs - API documentation"
+        },
+        "usage": {
+            "example": {
+                "url": "/generate",
+                "method": "POST",
+                "body": {
+                    "model": "tinyllama",
+                    "prompt": "Hello, how are you?",
+                    "temperature": 0.7,
+                    "max_tokens": 100
+                }
+            }
         },
         "status": "running"
     }
 
 if __name__ == "__main__":
     import uvicorn
-    logger.info("Starting Ollama API server...")
+    logger.info("Starting Ollama Generate API server...")
     uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
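
Note: the /generate and /health handlers call a get_ollama_client() helper that lives in the unchanged part of app.py and is not visible in this diff. A minimal sketch consistent with the "async with await get_ollama_client() as client" call sites is shown below; the timeout value and the plain httpx.AsyncClient construction are assumptions, not the file's actual implementation.

import httpx

async def get_ollama_client() -> httpx.AsyncClient:
    # Hypothetical helper: return (rather than enter) the client so callers
    # can write "async with await get_ollama_client() as client".
    # The timeout default is an assumption for illustration only.
    return httpx.AsyncClient(timeout=30.0)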
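
For reference, a client call matching the usage example embedded in the root endpoint (model "tinyllama", port 7860 from the uvicorn.run line) could look like the sketch below. Since this commit removes the /models/pull endpoint, the model must already be available in the local Ollama instance.

import httpx

payload = {
    "model": "tinyllama",
    "prompt": "Hello, how are you?",
    "temperature": 0.7,
    "max_tokens": 100,
}

# POST to the simplified /generate endpoint; generation can be slow,
# so allow a generous timeout to match the server's 300 s upstream call.
resp = httpx.post("http://localhost:7860/generate", json=payload, timeout=300.0)
resp.raise_for_status()
print(resp.json())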