from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import logging

# Configure logging to output information, warnings, and errors
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="NeuroBERT-Tiny Masked Language Model API",
    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
    version="1.0.0"
)
# Load the tokenizer and model globally to avoid reloading them on each request.
# This block runs once when the FastAPI application starts.
try:
    logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
    tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
    model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
    model.eval()  # Set model to evaluation mode for inference
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception("Failed to load model or tokenizer during startup!")
    # Re-raise so the application fails fast at startup rather than serving
    # requests without a usable model.
    raise RuntimeError(f"Could not load model: {e}")
class InferenceRequest(BaseModel):
    """
    Request model for the /predict endpoint.
    Expects a single string field 'text' containing the sentence with [MASK] tokens.
    """
    text: str
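# Illustrative request body for the /predict endpoint (the sentence is only an
# example, not taken from the model card):
#
#   {"text": "The capital of France is [MASK]."}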
class PredictionResult(BaseModel):
    """
    Response model for individual predictions from the /predict endpoint.
    """
    sequence: str   # The full sequence with the predicted token filled in
    score: float    # Confidence score of the prediction
    token: int      # The ID of the predicted token
    token_str: str  # The string representation of the predicted token
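# Illustrative response shape: the /predict endpoint returns a JSON list of
# PredictionResult objects. The values below are hypothetical placeholders,
# not actual model output:
#
#   [
#     {"sequence": "the capital of france is paris.", "score": 0.42,
#      "token": 3000, "token_str": "paris"},
#     ...
#   ]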
@app.post("/predict", response_model=list[PredictionResult])
async def predict_masked_lm(request: InferenceRequest):
""" | |
Predicts the most likely tokens for [MASK] positions in the input text. | |
Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details. | |
""" | |
try: | |
text = request.text | |
logger.info(f"Received prediction request for text: '{text}'") | |
# Tokenize the input text | |
inputs = tokenizer(text, return_tensors="pt") | |
# Perform inference without tracking gradients | |
with torch.no_grad(): | |
outputs = model(**inputs) | |
logits = outputs.logits | |
masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]") | |
# Find all masked token positions in the input IDs | |
masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1] | |
if not masked_token_indices.numel(): | |
logger.warning("No [MASK] token found in the input text.") | |
raise HTTPException( | |
status_code=status.HTTP_400_BAD_REQUEST, | |
detail="Input text must contain at least one '[MASK]' token." | |
) | |
results = [] | |
# Iterate over each masked token found in the input | |
for masked_index in masked_token_indices: | |
            # Get the top 5 predicted token IDs for the current masked position
            top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
            # Convert the logits at this position to probabilities once, then
            # look up the score of each candidate token.
            probabilities = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)

            # For each of the top 5 predictions
            for i in range(5):
                predicted_token_id = top_5_tokens[i].item()
                score = probabilities[predicted_token_id].item()
                predicted_token_str = tokenizer.decode(predicted_token_id)
                # Create a temporary input_ids tensor to replace the [MASK] token
                # with the current predicted token for generating the full sequence.
                temp_input_ids = inputs["input_ids"].clone()
                temp_input_ids[0, masked_index] = predicted_token_id
                # Decode the entire sequence, skipping special tokens, to get the complete predicted sentence.
                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)

                # Append the prediction result to our list
                results.append(PredictionResult(
                    sequence=full_sequence,
                    score=score,
                    token=predicted_token_id,
                    token_str=predicted_token_str
                ))

        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
        return results

    except ValidationError as e:
        logger.error(f"Validation error for request: {e.errors()}")
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors()
        )
    except HTTPException:
        # Re-raise explicit HTTPExceptions (e.g., 400 for missing [MASK])
        raise
    except Exception as e:
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}"
        )
@app.get("/")
async def root():
""" | |
Provides a basic health check endpoint for the API. | |
""" | |
logger.info("Health check endpoint accessed.") | |
return {"message": "NeuroBERT-Tiny API is running!"} | |
# This block is for running the app directly, typically used for local development. | |
# In a Docker container, Uvicorn (or Gunicorn) is usually invoked via the CMD in Dockerfile. | |
if __name__ == "__main__": | |
import uvicorn | |
# The 'reload=True' is great for local development for auto-reloading changes. | |
# For production in a Docker container, it's typically omitted for performance. | |
uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info") |
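# ---------------------------------------------------------------------------
# Example client usage: a minimal sketch showing how the /predict endpoint
# could be called, assuming the API is running locally on port 8000 as in the
# uvicorn.run() call above. The sentence is illustrative; run this as a
# separate script (it is kept commented out here so it does not execute when
# this module is imported).
#
#   import requests
#
#   response = requests.post(
#       "http://localhost:8000/predict",
#       json={"text": "The weather today is [MASK]."},
#   )
#   response.raise_for_status()
#   for prediction in response.json():
#       print(f"{prediction['token_str']!r}: {prediction['score']:.4f}")
# ---------------------------------------------------------------------------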