brendon-ai committed
Commit a285a66 · verified · Parent: 40f8c64

Update app.py

Files changed (1): app.py (+250, -120)
app.py CHANGED
@@ -1,143 +1,273 @@
-from fastapi import FastAPI, HTTPException, status, APIRouter, Request
-from pydantic import BaseModel, ValidationError
-from transformers import AutoTokenizer, AutoModelForMaskedLM
-import torch
 import logging

 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

 app = FastAPI(
-    title="NeuroBERT-Tiny Masked Language Model API",
-    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
-    version="1.0.0"
 )

-api_router = APIRouter()
-
-# --- NeuroBERT-Tiny Model Configuration ---
-# Using boltuix/NeuroBERT-Tiny for Masked Language Modeling.
-MODEL_NAME = "boltuix/NeuroBERT-Tiny"
-# ----------------------------------------
-
-# Load model globally to avoid reloading on each request
-# This block runs once when the FastAPI application starts.
-try:
-    logger.info(f"Loading tokenizer and model for {MODEL_NAME}...")
-    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
-    model.eval()  # Set model to evaluation mode
-    logger.info("Model loaded successfully.")
-except Exception as e:
-    logger.exception(f"Failed to load model or tokenizer for {MODEL_NAME} during startup!")
-    raise RuntimeError(f"Could not load model: {e}")
-
-class InferenceRequest(BaseModel):
-    """
-    Request model for the /predict endpoint.
-    Expects a single string field 'text' containing the sentence with [MASK] tokens.
-    """
-    text: str
-
-class PredictionResult(BaseModel):
-    """
-    Response model for individual predictions from the /predict endpoint.
-    """
-    sequence: str   # The full sequence with the predicted token filled in
-    score: float    # Confidence score of the prediction
-    token: int      # The ID of the predicted token
-    token_str: str  # The string representation of the predicted token
-
-@api_router.post(
-    "/predict",  # Prediction endpoint
-    response_model=list[PredictionResult],
-    summary="Predicts masked tokens in a given text using NeuroBERT-Tiny",
-    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
-)
-async def predict_masked_lm(request: InferenceRequest):
-    """
-    Predicts the most likely tokens for [MASK] positions in the input text using the NeuroBERT-Tiny model.
-    Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details.
-    """
-    try:
-        text = request.text
-        logger.info(f"Received prediction request for text: '{text}'")

-        inputs = tokenizer(text, return_tensors="pt")

-        with torch.no_grad():
-            outputs = model(**inputs)

-        logits = outputs.logits
-        masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")

-        # Find all masked tokens
-        masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]

-        if not masked_token_indices.numel():
-            logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail="Input text must contain at least one '[MASK]' token."
-            )

-        results = []
-        for masked_index in masked_token_indices:
-            # Get top 5 predictions for the masked token
-            top_5_logits = torch.topk(logits[0, masked_index], 5).values
-            top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
-
-            for i in range(5):
-                score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
-                predicted_token_id = top_5_tokens[i].item()
-                predicted_token_str = tokenizer.decode(predicted_token_id)
-
-                # Replace the [MASK] with the predicted token for the full sequence
-                temp_input_ids = inputs["input_ids"].clone()
-                temp_input_ids[0, masked_index] = predicted_token_id
-                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
-
-                results.append(PredictionResult(
-                    sequence=full_sequence,
-                    score=score,
-                    token=predicted_token_id,
-                    token_str=predicted_token_str
-                ))
-
-        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
-        return results
-
-    except ValidationError as e:
-        logger.error(f"Validation error for request: {e.errors()}")
-        raise HTTPException(
-            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-            detail=e.errors()
-        )
-    except HTTPException:
-        raise  # Re-raise custom HTTPExceptions
     except Exception as e:
-        logger.exception(f"An unexpected error occurred during prediction: {e}")
         raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"An internal server error occurred: {e}"
         )

-@api_router.get(
-    "/health",  # Health check endpoint
-    summary="Health Check",
-    description="Returns a simple message indicating the API is running."
-)
-async def health_check():
-    logger.info("Health check endpoint accessed.")
-    return {"message": "NeuroBERT-Tiny API is running!"}

-app.include_router(api_router)

-@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
-async def catch_all(request: Request, path_name: str):
-    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
-    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")

 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
+from fastapi import FastAPI, HTTPException, BackgroundTasks
+from pydantic import BaseModel, Field
+from typing import List, Optional, Dict, Any
+import httpx
+import asyncio
 import logging
+import time
+import json

+# Configure logging
 logging.basicConfig(level=logging.INFO)
 logger = logging.getLogger(__name__)

+# FastAPI app
 app = FastAPI(
+    title="Ollama API Server",
+    description="REST API for running Ollama models",
+    version="1.0.0",
+    docs_url="/docs",
+    redoc_url="/redoc"
 )

+# Ollama server configuration
+OLLAMA_BASE_URL = "http://localhost:11434"
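If Ollama does not run on the same host, the hard-coded base URL above has to change. A minimal sketch, not part of this commit, assuming you want it configurable: read the value from an environment variable and fall back to the default used here.

import os

# Hypothetical override via the OLLAMA_BASE_URL environment variable
OLLAMA_BASE_URL = os.getenv("OLLAMA_BASE_URL", "http://localhost:11434")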
 
+# Pydantic models
+class ChatMessage(BaseModel):
+    role: str = Field(..., description="Role of the message sender (user, assistant, system)")
+    content: str = Field(..., description="Content of the message")
+
+class ChatRequest(BaseModel):
+    model: str = Field(..., description="Model name to use for chat")
+    messages: List[ChatMessage] = Field(..., description="List of chat messages")
+    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
+    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
+    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
+    stream: Optional[bool] = Field(False, description="Whether to stream the response")
+
+class GenerateRequest(BaseModel):
+    model: str = Field(..., description="Model name to use for generation")
+    prompt: str = Field(..., description="Input prompt for text generation")
+    temperature: Optional[float] = Field(0.7, ge=0.0, le=2.0, description="Sampling temperature")
+    top_p: Optional[float] = Field(0.9, ge=0.0, le=1.0, description="Top-p sampling parameter")
+    max_tokens: Optional[int] = Field(512, ge=1, le=4096, description="Maximum tokens to generate")
+    stream: Optional[bool] = Field(False, description="Whether to stream the response")
+
+class ModelPullRequest(BaseModel):
+    model: str = Field(..., description="Model name to pull (e.g., 'llama2:7b')")
+
+class ChatResponse(BaseModel):
+    model: str
+    response: str
+    done: bool
+    total_duration: Optional[int] = None
+    load_duration: Optional[int] = None
+    prompt_eval_count: Optional[int] = None
+    eval_count: Optional[int] = None
+
+class GenerateResponse(BaseModel):
+    model: str
+    response: str
+    done: bool
+    total_duration: Optional[int] = None
+    load_duration: Optional[int] = None
+    prompt_eval_count: Optional[int] = None
+    eval_count: Optional[int] = None
+
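For reference, a request body that satisfies the ChatRequest schema above might look like the following; the model tag is only an example and must already be available in Ollama.

# Example payload accepted by POST /chat (model name is illustrative)
example_chat_request = {
    "model": "llama2:7b",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Why is the sky blue?"}
    ],
    "temperature": 0.7,
    "top_p": 0.9,
    "max_tokens": 512,
    "stream": False
}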
+# HTTP client for Ollama API
+async def get_ollama_client():
+    return httpx.AsyncClient(timeout=300.0)  # 5 minute timeout
+
+@app.get("/health")
+async def health_check():
+    """Health check endpoint"""
+    try:
+        async with await get_ollama_client() as client:
+            response = await client.get(f"{OLLAMA_BASE_URL}/api/version")
+            if response.status_code == 200:
+                return {
+                    "status": "healthy",
+                    "ollama_status": "running",
+                    "ollama_version": response.json(),
+                    "timestamp": time.time()
+                }
+            else:
+                return {
+                    "status": "degraded",
+                    "ollama_status": "error",
+                    "error": f"Ollama returned status {response.status_code}",
+                    "timestamp": time.time()
+                }
     except Exception as e:
+        logger.error(f"Health check failed: {e}")
+        return {
+            "status": "unhealthy",
+            "ollama_status": "unreachable",
+            "error": str(e),
+            "timestamp": time.time()
+        }
+
+@app.get("/models")
+async def list_models():
+    """List available models"""
+    try:
+        async with await get_ollama_client() as client:
+            response = await client.get(f"{OLLAMA_BASE_URL}/api/tags")
+            response.raise_for_status()
+            return response.json()
+    except httpx.HTTPError as e:
+        logger.error(f"Failed to list models: {e}")
+        raise HTTPException(status_code=500, detail=f"Failed to list models: {str(e)}")
+
+@app.post("/models/pull")
+async def pull_model(request: ModelPullRequest, background_tasks: BackgroundTasks):
+    """Pull a model from the Ollama registry"""
+    try:
+        async with await get_ollama_client() as client:
+            # Start the pull request
+            pull_data = {"name": request.model}
+            response = await client.post(
+                f"{OLLAMA_BASE_URL}/api/pull",
+                json=pull_data,
+                timeout=1800.0  # 30 minute timeout for model pulling
+            )
+
+            if response.status_code == 200:
+                return {
+                    "status": "success",
+                    "message": f"Successfully initiated pull for model '{request.model}'",
+                    "model": request.model
+                }
+            else:
+                error_detail = response.text
+                logger.error(f"Failed to pull model: {error_detail}")
+                raise HTTPException(
+                    status_code=response.status_code,
+                    detail=f"Failed to pull model: {error_detail}"
+                )
+    except httpx.TimeoutException:
         raise HTTPException(
+            status_code=408,
+            detail="Model pull request timed out. Large models may take longer to download."
         )
+    except HTTPException:
+        raise  # Preserve the status code set above instead of collapsing it to 500
+    except Exception as e:
+        logger.error(f"Error pulling model: {e}")
+        raise HTTPException(status_code=500, detail=f"Error pulling model: {str(e)}")

+@app.delete("/models/{model_name}")
+async def delete_model(model_name: str):
+    """Delete a model"""
+    try:
+        async with await get_ollama_client() as client:
+            # httpx's delete() helper does not accept a request body, so use request() to send the JSON payload
+            response = await client.request("DELETE", f"{OLLAMA_BASE_URL}/api/delete", json={"name": model_name})
+            response.raise_for_status()
+            return {"status": "success", "message": f"Model '{model_name}' deleted successfully"}
+    except httpx.HTTPError as e:
+        logger.error(f"Failed to delete model: {e}")
+        raise HTTPException(status_code=500, detail=f"Failed to delete model: {str(e)}")
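A quick way to exercise the pull endpoint from Python, assuming the server runs locally on port 7860 (the port uvicorn binds at the bottom of the file) and using an illustrative model tag:

import httpx

# Sketch: pull a model through this API server rather than the Ollama CLI
resp = httpx.post(
    "http://localhost:7860/models/pull",
    json={"model": "llama2:7b"},
    timeout=1800.0,  # large models can take a long time to download
)
print(resp.status_code, resp.json())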
 
+@app.post("/chat", response_model=ChatResponse)
+async def chat_with_model(request: ChatRequest):
+    """Chat with a model"""
+    try:
+        # Convert messages to Ollama format
+        chat_data = {
+            "model": request.model,
+            "messages": [{"role": msg.role, "content": msg.content} for msg in request.messages],
+            "stream": request.stream,
+            "options": {
+                "temperature": request.temperature,
+                "top_p": request.top_p,
+                "num_predict": request.max_tokens
+            }
+        }
+
+        async with await get_ollama_client() as client:
+            response = await client.post(
+                f"{OLLAMA_BASE_URL}/api/chat",
+                json=chat_data,
+                timeout=300.0
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            return ChatResponse(
+                model=result.get("model", request.model),
+                response=result.get("message", {}).get("content", ""),
+                done=result.get("done", True),
+                total_duration=result.get("total_duration"),
+                load_duration=result.get("load_duration"),
+                prompt_eval_count=result.get("prompt_eval_count"),
+                eval_count=result.get("eval_count")
+            )
+
+    except httpx.HTTPError as e:
+        logger.error(f"Chat request failed: {e}")
+        # Only HTTPStatusError carries a response; transport errors do not
+        if isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 404:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
+            )
+        raise HTTPException(status_code=500, detail=f"Chat request failed: {str(e)}")
+    except Exception as e:
+        logger.error(f"Unexpected error in chat: {e}")
+        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
+
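Calling the chat endpoint from a client is then a single POST; a minimal sketch, again assuming the server is on localhost:7860 and the model has already been pulled:

import httpx

payload = {
    "model": "llama2:7b",  # illustrative; any model present in Ollama works
    "messages": [{"role": "user", "content": "Hello!"}],
}
resp = httpx.post("http://localhost:7860/chat", json=payload, timeout=300.0)
resp.raise_for_status()
print(resp.json()["response"])  # ChatResponse.response holds the assistant's reply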
+@app.post("/generate", response_model=GenerateResponse)
+async def generate_text(request: GenerateRequest):
+    """Generate text completion"""
+    try:
+        generate_data = {
+            "model": request.model,
+            "prompt": request.prompt,
+            "stream": request.stream,
+            "options": {
+                "temperature": request.temperature,
+                "top_p": request.top_p,
+                "num_predict": request.max_tokens
+            }
+        }
+
+        async with await get_ollama_client() as client:
+            response = await client.post(
+                f"{OLLAMA_BASE_URL}/api/generate",
+                json=generate_data,
+                timeout=300.0
+            )
+            response.raise_for_status()
+            result = response.json()
+
+            return GenerateResponse(
+                model=result.get("model", request.model),
+                response=result.get("response", ""),
+                done=result.get("done", True),
+                total_duration=result.get("total_duration"),
+                load_duration=result.get("load_duration"),
+                prompt_eval_count=result.get("prompt_eval_count"),
+                eval_count=result.get("eval_count")
+            )
+
+    except httpx.HTTPError as e:
+        logger.error(f"Generate request failed: {e}")
+        # Only HTTPStatusError carries a response; transport errors do not
+        if isinstance(e, httpx.HTTPStatusError) and e.response.status_code == 404:
+            raise HTTPException(
+                status_code=404,
+                detail=f"Model '{request.model}' not found. Try pulling it first with POST /models/pull"
+            )
+        raise HTTPException(status_code=500, detail=f"Generate request failed: {str(e)}")
+    except Exception as e:
+        logger.error(f"Unexpected error in generate: {e}")
+        raise HTTPException(status_code=500, detail=f"Unexpected error: {str(e)}")
 
+@app.get("/")
+async def root():
+    """Root endpoint with API information"""
+    return {
+        "message": "Ollama API Server",
+        "version": "1.0.0",
+        "endpoints": {
+            "health": "/health",
+            "models": "/models",
+            "pull_model": "/models/pull",
+            "chat": "/chat",
+            "generate": "/generate",
+            "docs": "/docs"
+        },
+        "status": "running"
+    }

 if __name__ == "__main__":
     import uvicorn
+    logger.info("Starting Ollama API server...")
+    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
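Taken together, a first session against the new server might look like the sketch below; it assumes the API runs on localhost:7860, Ollama itself is reachable at OLLAMA_BASE_URL, and the model tag is only an example.

import httpx

BASE = "http://localhost:7860"  # where uvicorn serves this app

# 1. Check that the proxy can reach Ollama
print(httpx.get(f"{BASE}/health").json())

# 2. Pull a model (large models can take a while)
httpx.post(f"{BASE}/models/pull", json={"model": "llama2:7b"}, timeout=1800.0)

# 3. Generate a completion with it
resp = httpx.post(
    f"{BASE}/generate",
    json={"model": "llama2:7b", "prompt": "Write a haiku about APIs."},
    timeout=300.0,
)
print(resp.json()["response"])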