brendon-ai committed
Commit 119dbbc · verified · 1 Parent(s): 32f060b

Update app.py

Files changed (1)
  1. app.py +115 -154
app.py CHANGED
@@ -1,161 +1,122 @@
-from fastapi import FastAPI, HTTPException, status, APIRouter, Request
-from pydantic import BaseModel, ValidationError
-from transformers import AutoTokenizer, AutoModelForMaskedLM
-import torch
-import logging
-
-# Configure logging to output information, warnings, and errors
-logging.basicConfig(level=logging.INFO)
-logger = logging.getLogger(__name__)
-
-app = FastAPI(
-    title="NeuroBERT-Tiny Masked Language Model API",
-    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
-    version="1.0.0"
-)
-
-# Create an API Router to manage endpoints.
-# This approach can sometimes help with routing issues in proxied environments.
-api_router = APIRouter()
-
-# Load model globally to avoid reloading on each request
-# This block runs once when the FastAPI application starts.
-try:
-    logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
-    tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
-    model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
-    model.eval()  # Set model to evaluation mode for inference
-    logger.info("Model loaded successfully.")
-except Exception as e:
-    logger.exception("Failed to load model or tokenizer during startup!")
-    raise RuntimeError(f"Could not load model: {e}")
-
-class InferenceRequest(BaseModel):
-    """
-    Request model for the main prediction endpoint.
-    Expects a single string field 'text' containing the sentence with [MASK] tokens.
-    """
-    text: str
-
-class PredictionResult(BaseModel):
-    """
-    Response model for individual predictions from the API.
-    """
-    sequence: str   # The full sequence with the predicted token filled in
-    score: float    # Confidence score of the prediction
-    token: int      # The ID of the predicted token
-    token_str: str  # The string representation of the predicted token
-
-@api_router.post(
-    "/",  # Changed from "/predict" to "/" to funnel all POST requests to the root
-    response_model=list[PredictionResult],
-    summary="Predicts masked tokens in a given text",
-    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
-)
-async def predict_masked_lm(request: InferenceRequest):
-    """
-    Predicts the most likely tokens for [MASK] positions in the input text.
-    Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details.
-    """
-    try:
-        text = request.text
-        logger.info(f"Received prediction request for text: '{text}'")
-
-        # Tokenize the input text
-        inputs = tokenizer(text, return_tensors="pt")
-
-        # Perform inference without tracking gradients
-        with torch.no_grad():
-            outputs = model(**inputs)
-
-        logits = outputs.logits
-        masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
-
-        # Find all masked token positions in the input IDs
-        masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
-
-        if not masked_token_indices.numel():
-            logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
-            raise HTTPException(
-                status_code=status.HTTP_400_BAD_REQUEST,
-                detail="Input text must contain at least one '[MASK]' token."
-            )
-
-        results = []
-        # Iterate over each masked token found in the input
-        for masked_index in masked_token_indices:
-            # Get top 5 predictions (logits and their corresponding token IDs) for the current masked position
-            top_5_logits = torch.topk(logits[0, masked_index], 5).values
-            top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
-
-            # For each of the top 5 predictions
-            for i in range(5):
-                # Calculate the softmax score for the predicted token
-                score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
-                predicted_token_id = top_5_tokens[i].item()
-                predicted_token_str = tokenizer.decode(predicted_token_id)
-
-                # Create a temporary input_ids tensor to replace the [MASK] token
-                # with the current predicted token for generating the full sequence.
-                temp_input_ids = inputs["input_ids"].clone()
-                temp_input_ids[0, masked_index] = predicted_token_id
-
-                # Decode the entire sequence, skipping special tokens, to get the complete predicted sentence.
-                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
-
-                # Append the prediction result to our list
-                results.append(PredictionResult(
-                    sequence=full_sequence,
-                    score=score,
-                    token=predicted_token_id,
-                    token_str=predicted_token_str
-                ))
-
-        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
-        return results
-
-    except ValidationError as e:
-        logger.error(f"Validation error for request: {e.errors()}")
-        raise HTTPException(
-            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
-            detail=e.errors()
-        )
-    except HTTPException:
-        # Re-raise explicit HTTPExceptions (e.g., 400 for missing [MASK])
-        raise
-    except Exception as e:
-        logger.exception(f"An unexpected error occurred during prediction: {e}")
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"An internal server error occurred: {e}"
-        )
-
-@api_router.get(
-    "/health",  # Health check moved to /health
-    summary="Health Check",
-    description="Returns a simple message indicating the API is running."
-)
-async def health_check():
-    """
-    Provides a basic health check endpoint for the API.
-    """
-    logger.info("Health check endpoint accessed.")
-    return {"message": "NeuroBERT-Tiny API is running!"}
-
-# Include the API router in the main FastAPI application
-app.include_router(api_router)
-
-# Optional: Add a catch-all route for any unhandled paths.
-# This can help log when requests are hitting the app but to an unknown path.
-# This should now catch anything not / or /health
-@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
-async def catch_all(request: Request, path_name: str):
-    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
-    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
-
-
-# This block is for running the app directly, typically used for local development.
-# In a Docker container, Uvicorn (or Gunicorn) is usually invoked via the CMD in Dockerfile.
-if __name__ == "__main__":
-    import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
+from fastapi import FastAPI, HTTPException, status, APIRouter, Request
+from pydantic import BaseModel, ValidationError
+from transformers import AutoTokenizer, AutoModelForMaskedLM
+import torch
+import logging
+
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(
+    title="NeuroBERT-Tiny Masked Language Model API",
+    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
+    version="1.0.0"
+)
+
+api_router = APIRouter()
+
+try:
+    logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
+    tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
+    model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
+    model.eval()
+    logger.info("Model loaded successfully.")
+except Exception as e:
+    logger.exception("Failed to load model or tokenizer during startup!")
+    raise RuntimeError(f"Could not load model: {e}")
+
+class InferenceRequest(BaseModel):
+    text: str
+
+class PredictionResult(BaseModel):
+    sequence: str
+    score: float
+    token: int
+    token_str: str
+
+@api_router.post(
+    "/predict",  # IMPORTANT: Prediction endpoint is now /predict
+    response_model=list[PredictionResult],
+    summary="Predicts masked tokens in a given text",
+    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
+)
+async def predict_masked_lm(request: InferenceRequest):
+    try:
+        text = request.text
+        logger.info(f"Received prediction request for text: '{text}'")
+
+        inputs = tokenizer(text, return_tensors="pt")
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        logits = outputs.logits
+        masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
+
+        masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
+
+        if not masked_token_indices.numel():
+            logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Input text must contain at least one '[MASK]' token."
+            )
+
+        results = []
+        for masked_index in masked_token_indices:
+            top_5_logits = torch.topk(logits[0, masked_index], 5).values
+            top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
+
+            for i in range(5):
+                score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
+                predicted_token_id = top_5_tokens[i].item()
+                predicted_token_str = tokenizer.decode(predicted_token_id)
+
+                temp_input_ids = inputs["input_ids"].clone()
+                temp_input_ids[0, masked_index] = predicted_token_id
+                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
+
+                results.append(PredictionResult(
+                    sequence=full_sequence,
+                    score=score,
+                    token=predicted_token_id,
+                    token_str=predicted_token_str
+                ))
+
+        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
+        return results
+
+    except ValidationError as e:
+        logger.error(f"Validation error for request: {e.errors()}")
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=e.errors()
+        )
+    except HTTPException:
+        raise
+    except Exception as e:
+        logger.exception(f"An unexpected error occurred during prediction: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"An internal server error occurred: {e}"
+        )
+
+@api_router.get(
+    "/health",  # IMPORTANT: Health check endpoint is /health
+    summary="Health Check",
+    description="Returns a simple message indicating the API is running."
+)
+async def health_check():
+    logger.info("Health check endpoint accessed.")
+    return {"message": "NeuroBERT-Tiny API is running!"}
+
+app.include_router(api_router)
+
+@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
+async def catch_all(request: Request, path_name: str):
+    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
+    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
+
+if __name__ == "__main__":
+    import uvicorn
+    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
+
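
Quick check of the new routes, as a minimal client sketch. It assumes the server from this commit is running locally on port 7860 (per the __main__ block); the requests dependency and the example sentence are illustrative assumptions, not part of the commit:

    # Hypothetical client for the updated API (not part of the commit).
    # Assumes the app is served at localhost:7860, as in the __main__ block.
    import requests

    BASE_URL = "http://localhost:7860"  # assumed host/port

    # GET /health returns the fixed status message from health_check().
    print(requests.get(f"{BASE_URL}/health").json())
    # -> {"message": "NeuroBERT-Tiny API is running!"}

    # POST /predict (previously POST /) returns the top-5 fill-ins per [MASK].
    resp = requests.post(
        f"{BASE_URL}/predict",
        json={"text": "The capital of France is [MASK]."},
    )
    resp.raise_for_status()
    for pred in resp.json():
        print(f"{pred['score']:.4f}  {pred['token_str']!r}  {pred['sequence']}")

Note that POST requests that previously hit "/" now fall through to the catch-all route and return 404; clients must target /predict.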
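
The four response fields (sequence, score, token, token_str) line up with what transformers' fill-mask pipeline returns, so the endpoint's output can also be sanity-checked offline against the pipeline. A sketch under that assumption, using the same checkpoint the startup block loads:

    # Offline cross-check (illustrative, not part of the commit): the fill-mask
    # pipeline emits dicts with the same keys as PredictionResult.
    from transformers import pipeline

    fill_mask = pipeline("fill-mask", model="boltuix/NeuroBERT-Tiny")
    for pred in fill_mask("The weather today is [MASK].", top_k=5):
        print(f"{pred['score']:.4f}  {pred['token_str']!r}  {pred['sequence']}")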