brendon-ai committed on
Commit
de7dbb0
·
verified ·
1 Parent(s): c4bd657

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +98 -99
app.py CHANGED
@@ -1,86 +1,86 @@
1
- from fastapi import FastAPI, HTTPException, status, APIRouter, Request
2
- from pydantic import BaseModel, ValidationError
3
- from transformers import AutoTokenizer, AutoModelForMaskedLM
4
- import torch
5
- import logging
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6
 
7
- logging.basicConfig(level=logging.INFO)
8
- logger = logging.getLogger(__name__)
 
9
 
10
- app = FastAPI(
11
- title="NeuroBERT-Tiny Masked Language Model API",
12
- description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
13
- version="1.0.0"
14
- )
15
 
16
- api_router = APIRouter()
17
 
18
- try:
19
- logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
20
- tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
21
- model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
22
- model.eval()
23
- logger.info("Model loaded successfully.")
24
- except Exception as e:
25
- logger.exception("Failed to load model or tokenizer during startup!")
26
- raise RuntimeError(f"Could not load model: {e}")
27
-
28
- class InferenceRequest(BaseModel):
29
- text: str
30
-
31
- class PredictionResult(BaseModel):
32
- sequence: str
33
- score: float
34
- token: int
35
- token_str: str
36
-
37
- @api_router.post(
38
- "/predict", # IMPORTANT: Prediction endpoint is now /predict
39
- response_model=list[PredictionResult],
40
- summary="Predicts masked tokens in a given text",
41
- description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
42
- )
43
- async def predict_masked_lm(request: InferenceRequest):
44
- try:
45
- text = request.text
46
- logger.info(f"Received prediction request for text: '{text}'")
47
-
48
- inputs = tokenizer(text, return_tensors="pt")
49
- with torch.no_grad():
50
- outputs = model(**inputs)
51
-
52
- logits = outputs.logits
53
- masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
54
-
55
- masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
56
-
57
- if not masked_token_indices.numel():
58
- logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
59
- raise HTTPException(
60
- status_code=status.HTTP_400_BAD_REQUEST,
61
- detail="Input text must contain at least one '[MASK]' token."
62
- )
63
-
64
- results = []
65
- for masked_index in masked_token_indices:
66
- top_5_logits = torch.topk(logits[0, masked_index], 5).values
67
- top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
68
-
69
- for i in range(5):
70
- score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
71
- predicted_token_id = top_5_tokens[i].item()
72
- predicted_token_str = tokenizer.decode(predicted_token_id)
73
-
74
- temp_input_ids = inputs["input_ids"].clone()
75
- temp_input_ids[0, masked_index] = predicted_token_id
76
- full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
77
-
78
- results.append(PredictionResult(
79
- sequence=full_sequence,
80
- score=score,
81
- token=predicted_token_id,
82
- token_str=predicted_token_str
83
- ))
84
 
85
  logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
86
  return results
@@ -100,23 +100,22 @@
100
  detail=f"An internal server error occurred: {e}"
101
  )
102
 
103
- @api_router.get(
104
- "/health", # IMPORTANT: Health check endpoint is /health
105
- summary="Health Check",
106
- description="Returns a simple message indicating the API is running."
107
- )
108
- async def health_check():
109
- logger.info("Health check endpoint accessed.")
110
- return {"message": "NeuroBERT-Tiny API is running!"}
111
-
112
- app.include_router(api_router)
113
-
114
- @app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
115
- async def catch_all(request: Request, path_name: str):
116
- logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
117
- raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
118
-
119
- if __name__ == "__main__":
120
- import uvicorn
121
- uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
122
-
 
1
+ from fastapi import FastAPI, HTTPException, status, APIRouter, Request
2
+ from pydantic import BaseModel, ValidationError
3
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
4
+ import torch
5
+ import logging
6
+
7
+ logging.basicConfig(level=logging.INFO)
8
+ logger = logging.getLogger(__name__)
9
+
10
+ app = FastAPI(
11
+ title="NeuroBERT-Tiny Masked Language Model API",
12
+ description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
13
+ version="1.0.0"
14
+ )
15
+
16
+ api_router = APIRouter()
17
+
18
+ try:
19
+ logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
20
+ tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
21
+ model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
22
+ model.eval()
23
+ logger.info("Model loaded successfully.")
24
+ except Exception as e:
25
+ logger.exception("Failed to load model or tokenizer during startup!")
26
+ raise RuntimeError(f"Could not load model: {e}")
27
+
28
+ class InferenceRequest(BaseModel):
29
+ text: str
30
+
31
+ class PredictionResult(BaseModel):
32
+ sequence: str
33
+ score: float
34
+ token: int
35
+ token_str: str
36
+
37
+ @api_router.post(
38
+ "/predict", # Prediction endpoint
39
+ response_model=list[PredictionResult],
40
+ summary="Predicts masked tokens in a given text",
41
+ description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
42
+ )
43
+ async def predict_masked_lm(request: InferenceRequest):
44
+ try:
45
+ text = request.text
46
+ logger.info(f"Received prediction request for text: '{text}'")
47
 
48
+ inputs = tokenizer(text, return_tensors="pt")
49
+ with torch.no_grad():
50
+ outputs = model(**inputs)
51
 
52
+ logits = outputs.logits
53
+ masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
 
 
 
54
 
55
+ masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
56
 
57
+ if not masked_token_indices.numel():
58
+ logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
59
+ raise HTTPException(
60
+ status_code=status.HTTP_400_BAD_REQUEST,
61
+ detail="Input text must contain at least one '[MASK]' token."
62
+ )
63
+
64
+ results = []
65
+ for masked_index in masked_token_indices:
66
+ top_5_logits = torch.topk(logits[0, masked_index], 5).values
67
+ top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
68
+
69
+ for i in range(5):
70
+ score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
71
+ predicted_token_id = top_5_tokens[i].item()
72
+ predicted_token_str = tokenizer.decode(predicted_token_id)
73
+
74
+ temp_input_ids = inputs["input_ids"].clone()
75
+ temp_input_ids[0, masked_index] = predicted_token_id
76
+ full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
77
+
78
+ results.append(PredictionResult(
79
+ sequence=full_sequence,
80
+ score=score,
81
+ token=predicted_token_id,
82
+ token_str=predicted_token_str
83
+ ))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
86
  return results
 
100
  detail=f"An internal server error occurred: {e}"
101
  )
102
 
103
+ @api_router.get(
104
+ "/health", # Health check endpoint
105
+ summary="Health Check",
106
+ description="Returns a simple message indicating the API is running."
107
+ )
108
+ async def health_check():
109
+ logger.info("Health check endpoint accessed.")
110
+ return {"message": "NeuroBERT-Tiny API is running!"}
111
+
112
+ app.include_router(api_router)
113
+
114
+ @app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
115
+ async def catch_all(request: Request, path_name: str):
116
+ logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
117
+ raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
118
+
119
+ if __name__ == "__main__":
120
+ import uvicorn
121
+ uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")