from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import logging
# Configure logging to output information, warnings, and errors
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="NeuroBERT-Tiny Masked Language Model API",
    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
    version="1.0.0"
)

# Load model globally to avoid reloading on each request
# This block runs once when the FastAPI application starts.
try:
    logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
    tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
    model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
    model.eval()  # Set model to evaluation mode for inference
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception("Failed to load model or tokenizer during startup!")
    # Re-raise so the application fails fast at startup rather than starting
    # in a broken state and failing on every prediction request.
    raise RuntimeError(f"Could not load model: {e}")

class InferenceRequest(BaseModel):
    """
    Request model for the /predict endpoint.
    Expects a single string field 'text' containing the sentence with [MASK] tokens.
    """
    text: str

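# Illustrative example of a request body matching this model (the sentence is
# only an example, not taken from the model card):
#   {"text": "The capital of France is [MASK]."}
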
class PredictionResult(BaseModel):
    """
    Response model for individual predictions from the /predict endpoint.
    """
    sequence: str   # The full sequence with the predicted token filled in
    score: float    # Confidence score of the prediction
    token: int      # The ID of the predicted token
    token_str: str  # The string representation of the predicted token

@app.post(
    "/predict",
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text",
    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
)
async def predict_masked_lm(request: InferenceRequest):
    """
    Predicts the most likely tokens for [MASK] positions in the input text.
    Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details.
    """
    try:
        text = request.text
        logger.info(f"Received prediction request for text: '{text}'")

        # Tokenize the input text
        inputs = tokenizer(text, return_tensors="pt")

        # Perform inference without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)
            logits = outputs.logits

        masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
        # Find all masked token positions in the input IDs
        masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]

        if not masked_token_indices.numel():
            logger.warning("No [MASK] token found in the input text.")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Input text must contain at least one '[MASK]' token."
            )
        results = []
        # Iterate over each masked token found in the input
        for masked_index in masked_token_indices:
            # Softmax over the vocabulary gives a probability for every candidate
            # token at this masked position; keep the 5 most likely candidates.
            probabilities = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
            top_5 = torch.topk(probabilities, 5)

            # For each of the top 5 predictions
            for score_tensor, token_tensor in zip(top_5.values, top_5.indices):
                score = score_tensor.item()
                predicted_token_id = token_tensor.item()
                predicted_token_str = tokenizer.decode(predicted_token_id)

                # Create a temporary input_ids tensor to replace the [MASK] token
                # with the current predicted token for generating the full sequence.
                temp_input_ids = inputs["input_ids"].clone()
                temp_input_ids[0, masked_index] = predicted_token_id

                # Decode the entire sequence, skipping special tokens, to get the complete predicted sentence.
                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)

                # Append the prediction result to our list
                results.append(PredictionResult(
                    sequence=full_sequence,
                    score=score,
                    token=predicted_token_id,
                    token_str=predicted_token_str
                ))
logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
return results
except ValidationError as e:
logger.error(f"Validation error for request: {e.errors()}")
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=e.errors()
)
except HTTPException:
# Re-raise explicit HTTPExceptions (e.g., 400 for missing [MASK])
raise
except Exception as e:
logger.exception(f"An unexpected error occurred during prediction: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An internal server error occurred: {e}"
)
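# Example interaction with the /predict endpoint (illustrative sketch; it assumes
# the service is running locally on port 8000, as configured at the bottom of
# this file, and the sentence is just an example):
#
#   curl -X POST http://localhost:8000/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The weather today is [MASK]."}'
#
# The response is a JSON array with one object per top-5 prediction for each
# [MASK] position, each carrying "sequence", "score", "token", and "token_str".
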
@app.get(
    "/",
    summary="Health Check",
    description="Returns a simple message indicating the API is running."
)
async def root():
    """
    Provides a basic health check endpoint for the API.
    """
    logger.info("Health check endpoint accessed.")
    return {"message": "NeuroBERT-Tiny API is running!"}

# This block is for running the app directly, typically used for local development.
# In a Docker container, Uvicorn (or Gunicorn) is usually invoked via the CMD in Dockerfile.
if __name__ == "__main__":
    import uvicorn
    # 'reload=True' is handy during local development to auto-reload on code changes.
    # For production in a Docker container, it's typically omitted for performance.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")
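
# A minimal sketch of the corresponding Dockerfile CMD (this assumes the file is
# saved as app.py; adjust the "app:app" module path to match the actual filename):
#   CMD ["uvicorn", "app:app", "--host", "0.0.0.0", "--port", "8000"]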