from fastapi import FastAPI, HTTPException, status, APIRouter, Request
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="NeuroBERT-Tiny Masked Language Model API",
    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
    version="1.0.0"
)
api_router = APIRouter()
# --- NeuroBERT-Tiny Model Configuration ---
# Using boltuix/NeuroBERT-Tiny for Masked Language Modeling.
MODEL_NAME = "boltuix/NeuroBERT-Tiny"
# ----------------------------------------
# Load model globally to avoid reloading on each request
# This block runs once when the FastAPI application starts.
try:
    logger.info(f"Loading tokenizer and model for {MODEL_NAME}...")
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    model.eval()  # Set model to evaluation mode (disables dropout)
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception(f"Failed to load model or tokenizer for {MODEL_NAME} during startup!")
    raise RuntimeError(f"Could not load model: {e}") from e
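# Optional sketch (an assumption, not part of the original app): inference could
# run on a GPU when one is available. If enabled, the handler below would also
# need to move `inputs` to the same device before calling the model.
# device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# model.to(device)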
class InferenceRequest(BaseModel):
    """
    Request model for the /predict endpoint.
    Expects a single string field 'text' containing the sentence with [MASK] tokens.
    """
    text: str
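# Illustrative request body for /predict (the sentence is an arbitrary example):
#   {"text": "The capital of France is [MASK]."}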
class PredictionResult(BaseModel):
    """
    Response model for individual predictions from the /predict endpoint.
    """
    sequence: str   # The full sequence with the predicted token filled in
    score: float    # Confidence score (softmax probability) of the prediction
    token: int      # The ID of the predicted token
    token_str: str  # The string representation of the predicted token
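# Illustrative shape of one element in the /predict response (values are made up):
#   {"sequence": "the capital of france is paris.", "score": 0.42,
#    "token": 3000, "token_str": "paris"}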
@api_router.post(
    "/predict",  # Prediction endpoint
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text using NeuroBERT-Tiny",
    description="Accepts a text string with '[MASK]' tokens and returns the top 5 predictions for each masked position."
)
async def predict_masked_lm(request: InferenceRequest):
    """
    Predicts the most likely tokens for [MASK] positions in the input text using the NeuroBERT-Tiny model.
    Returns a list of the top 5 predictions for each masked token, including the full sequence, score, and token details.
    """
    try:
        text = request.text
        logger.info(f"Received prediction request for text: '{text}'")
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        # Find the positions of all [MASK] tokens in the input.
        masked_token_indices = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
        if not masked_token_indices.numel():
            logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Input text must contain at least one '[MASK]' token."
            )
        results = []
        for masked_index in masked_token_indices:
            # Convert logits to probabilities once per masked position, then take the top 5.
            probabilities = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
            top_5 = torch.topk(probabilities, 5)
            for score_tensor, token_tensor in zip(top_5.values, top_5.indices):
                predicted_token_id = token_tensor.item()
                predicted_token_str = tokenizer.decode(predicted_token_id)
                # Replace the [MASK] with the predicted token to build the full sequence.
                temp_input_ids = inputs["input_ids"].clone()
                temp_input_ids[0, masked_index] = predicted_token_id
                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
                results.append(PredictionResult(
                    sequence=full_sequence,
                    score=score_tensor.item(),
                    token=predicted_token_id,
                    token_str=predicted_token_str
                ))
        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
        return results
    except ValidationError as e:
        # FastAPI normally rejects malformed bodies before the handler runs;
        # this is a defensive fallback for validation errors raised manually.
        logger.error(f"Validation error for request: {e.errors()}")
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors()
        )
    except HTTPException:
        raise  # Re-raise HTTPExceptions (e.g., the 400 above) unchanged
    except Exception as e:
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}"
        )
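# Example call (illustrative; assumes the server started in __main__ below, on port 7860):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The capital of France is [MASK]."}'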
@api_router.get(
    "/health",  # Health check endpoint
    summary="Health Check",
    description="Returns a simple message indicating the API is running."
)
async def health_check():
    logger.info("Health check endpoint accessed.")
    return {"message": "NeuroBERT-Tiny API is running!"}
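# Example call (illustrative, same local-server assumption as above):
#   curl http://localhost:7860/health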
app.include_router(api_router)
@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def catch_all(request: Request, path_name: str):
    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
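# Alternatively, run via the uvicorn CLI (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860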