brendon-ai committed on
Commit 1d0ba1f · verified · 1 Parent(s): b04a975

Update app.py

Files changed (1)
  1. app.py +133 -51
app.py CHANGED
@@ -1,68 +1,150 @@
-from fastapi import FastAPI
-from pydantic import BaseModel
+from fastapi import FastAPI, HTTPException, status
+from pydantic import BaseModel, ValidationError
 from transformers import AutoTokenizer, AutoModelForMaskedLM
 import torch
+import logging
 
-app = FastAPI()
+# Configure logging to output information, warnings, and errors
+logging.basicConfig(level=logging.INFO)
+logger = logging.getLogger(__name__)
+
+app = FastAPI(
+    title="NeuroBERT-Tiny Masked Language Model API",
+    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
+    version="1.0.0"
+)
 
 # Load model globally to avoid reloading on each request
-tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
-model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
-model.eval() # Set model to evaluation mode
+# This block runs once when the FastAPI application starts.
+try:
+    logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
+    tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
+    model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
+    model.eval() # Set model to evaluation mode for inference
+    logger.info("Model loaded successfully.")
+except Exception as e:
+    logger.exception("Failed to load model or tokenizer during startup!")
+    # Depending on the deployment, you might want to raise an exception here
+    # to prevent the app from starting if the model can't be loaded.
+    # For now, we'll let it potentially start and fail on prediction.
+    raise RuntimeError(f"Could not load model: {e}")
 
 class InferenceRequest(BaseModel):
+    """
+    Request model for the /predict endpoint.
+    Expects a single string field 'text' containing the sentence with [MASK] tokens.
+    """
     text: str
 
 class PredictionResult(BaseModel):
-    sequence: str
-    score: float
-    token: int
-    token_str: str
+    """
+    Response model for individual predictions from the /predict endpoint.
+    """
+    sequence: str # The full sequence with the predicted token filled in
+    score: float # Confidence score of the prediction
+    token: int # The ID of the predicted token
+    token_str: str # The string representation of the predicted token
 
-@app.post("/predict", response_model=list[PredictionResult])
+@app.post(
+    "/predict",
+    response_model=list[PredictionResult],
+    summary="Predicts masked tokens in a given text",
+    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
+)
 async def predict_masked_lm(request: InferenceRequest):
-    text = request.text
-    inputs = tokenizer(text, return_tensors="pt")
-
-    with torch.no_grad():
-        outputs = model(**inputs)
-
-    logits = outputs.logits
-    masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
-
-    # Find all masked tokens
-    masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
-
-    results = []
-    for masked_index in masked_token_indices:
-        # Get top 5 predictions for the masked token
-        top_5_logits = torch.topk(logits[0, masked_index], 5).values
-        top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
-
-        for i in range(5):
-            score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
-            predicted_token_id = top_5_tokens[i].item()
-            predicted_token_str = tokenizer.decode(predicted_token_id)
-
-            # Replace the [MASK] with the predicted token for the full sequence
-            # Create a temporary input_ids tensor to get the sequence
-            temp_input_ids = inputs["input_ids"].clone()
-            temp_input_ids[0, masked_index] = predicted_token_id
-            full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
-
-            results.append(PredictionResult(
-                sequence=full_sequence,
-                score=score,
-                token=predicted_token_id,
-                token_str=predicted_token_str
-            ))
-    return results
-
-# Optional: A simple health check endpoint
-@app.get("/")
+    """
+    Predicts the most likely tokens for [MASK] positions in the input text.
+    Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details.
+    """
+    try:
+        text = request.text
+        logger.info(f"Received prediction request for text: '{text}'")
+
+        # Tokenize the input text
+        inputs = tokenizer(text, return_tensors="pt")
+
+        # Perform inference without tracking gradients
+        with torch.no_grad():
+            outputs = model(**inputs)
+
+        logits = outputs.logits
+        masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
+
+        # Find all masked token positions in the input IDs
+        masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
+
+        if not masked_token_indices.numel():
+            logger.warning("No [MASK] token found in the input text.")
+            raise HTTPException(
+                status_code=status.HTTP_400_BAD_REQUEST,
+                detail="Input text must contain at least one '[MASK]' token."
+            )
+
+        results = []
+        # Iterate over each masked token found in the input
+        for masked_index in masked_token_indices:
+            # Get top 5 predictions (logits and their corresponding token IDs) for the current masked position
+            top_5_logits = torch.topk(logits[0, masked_index], 5).values
+            top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
+
+            # For each of the top 5 predictions
+            for i in range(5):
+                # Calculate the softmax score for the predicted token
+                score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
+                predicted_token_id = top_5_tokens[i].item()
+                predicted_token_str = tokenizer.decode(predicted_token_id)
+
+                # Create a temporary input_ids tensor to replace the [MASK] token
+                # with the current predicted token for generating the full sequence.
+                temp_input_ids = inputs["input_ids"].clone()
+                temp_input_ids[0, masked_index] = predicted_token_id
+
+                # Decode the entire sequence, skipping special tokens, to get the complete predicted sentence.
+                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
+
+                # Append the prediction result to our list
+                results.append(PredictionResult(
+                    sequence=full_sequence,
+                    score=score,
+                    token=predicted_token_id,
+                    token_str=predicted_token_str
+                ))
+
+        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
+        return results
+
+    except ValidationError as e:
+        logger.error(f"Validation error for request: {e.errors()}")
+        raise HTTPException(
+            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
+            detail=e.errors()
+        )
+    except HTTPException:
+        # Re-raise explicit HTTPExceptions (e.g., 400 for missing [MASK])
+        raise
+    except Exception as e:
+        logger.exception(f"An unexpected error occurred during prediction: {e}")
+        raise HTTPException(
+            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
+            detail=f"An internal server error occurred: {e}"
+        )
+
+@app.get(
+    "/",
+    summary="Health Check",
+    description="Returns a simple message indicating the API is running."
+)
 async def root():
+    """
+    Provides a basic health check endpoint for the API.
+    """
+    logger.info("Health check endpoint accessed.")
     return {"message": "NeuroBERT-Tiny API is running!"}
 
+# This block is for running the app directly, typically used for local development.
+# In a Docker container, Uvicorn (or Gunicorn) is usually invoked via the CMD in Dockerfile.
 if __name__ == "__main__":
     import uvicorn
-    uvicorn.run(app, host="0.0.0.0", port=8000)
+    # The 'reload=True' is great for local development for auto-reloading changes.
+    # For production in a Docker container, it's typically omitted for performance.
+    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
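For reference, a minimal sketch of how a client might call the updated /predict endpoint once the app is running. The URL, port, timeout, and example sentence are illustrative assumptions, not part of the commit; the request body and response fields follow the InferenceRequest and PredictionResult models in the diff above, and the third-party requests package is assumed to be installed on the client side.

import requests  # assumption: 'requests' is available in the client environment

# Assumed local deployment; adjust the URL to wherever the API is actually served.
API_URL = "http://localhost:8000/predict"

# Hypothetical example input; must contain at least one [MASK] token,
# otherwise the endpoint now responds with HTTP 400.
payload = {"text": "The capital of France is [MASK]."}

response = requests.post(API_URL, json=payload, timeout=30)
response.raise_for_status()

# Each item mirrors PredictionResult: sequence, score, token, token_str.
for prediction in response.json():
    print(f"{prediction['token_str']!r} ({prediction['score']:.4f}): {prediction['sequence']}")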