# Hugging Face Space page status (scraped header): Sleeping
import logging

import torch
from fastapi import APIRouter, FastAPI, HTTPException, Request, status
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel, ValidationError
from transformers import AutoModelForMaskedLM, AutoTokenizer
# Configure the root logger at import time and grab this module's logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Application metadata published in the generated OpenAPI schema.
app = FastAPI(
    version="1.0.0",
    title="NeuroBERT-Tiny Masked Language Model API",
    description=(
        "An API to perform Masked Language Modeling "
        "using the boltuix/NeuroBERT-Tiny model."
    ),
)

# Router that is later attached to ``app`` via ``app.include_router``.
api_router = APIRouter()
# Hugging Face Hub id of the masked-LM checkpoint served by this API.
MODEL_NAME = "boltuix/NeuroBERT-Tiny"

# Load the tokenizer and model once at import time. Any failure aborts startup
# instead of letting the app come up without a model.
try:
    logger.info("Loading tokenizer and model for %s...", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    # Inference-only: put layers such as dropout into evaluation mode.
    model.eval()
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception("Failed to load model or tokenizer during startup!")
    # Explicit chaining keeps the original loader error as the direct cause.
    raise RuntimeError(f"Could not load model: {e}") from e
class InferenceRequest(BaseModel):
    """Request body for masked-language-model prediction."""

    # Text to complete; the prediction handler rejects it (400) unless it
    # contains at least one "[MASK]" token.
    text: str
class PredictionResult(BaseModel):
    """One candidate fill for a single [MASK] position."""

    # Input text with this mask replaced by the candidate, decoded with
    # special tokens skipped.
    sequence: str
    # Softmax probability of the candidate at that mask position.
    score: float
    # Vocabulary id of the candidate token.
    token: int
    # Decoded string form of the candidate token.
    token_str: str
def _predict_masked_tokens(text: str, top_k: int = 5) -> list:
    """Run the masked LM on *text* and return the top-*top_k* fills for each [MASK].

    This is blocking work (tokenization plus a forward pass), so call it off
    the event loop.

    Args:
        text: Input text containing one or more ``[MASK]`` tokens.
        top_k: Number of candidates to return per mask position.

    Returns:
        A flat list of ``PredictionResult``, grouped by mask position in
        sequence order, and by descending score within each position.

    Raises:
        HTTPException: 400 if *text* has no ``[MASK]`` token, or if it is
            longer than the model's position-embedding limit.
    """
    inputs = tokenizer(text, return_tensors="pt")
    input_ids = inputs["input_ids"]

    masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
    # The batch has one row, so column indices are the mask positions.
    masked_positions = torch.where(input_ids == masked_token_id)[1]
    # Reject before the forward pass so a mask-less request costs no inference.
    if not masked_positions.numel():
        logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Input text must contain at least one '[MASK]' token.",
        )

    # Over-long inputs would otherwise presumably fail inside the model's
    # position embeddings and surface as a 500, but they are a client error.
    # getattr is used because not every config defines this attribute.
    max_positions = getattr(model.config, "max_position_embeddings", None)
    seq_len = input_ids.shape[1]
    if max_positions is not None and seq_len > max_positions:
        logger.warning(
            "Input of %d tokens exceeds the model limit of %d. Returning 400 Bad Request.",
            seq_len,
            max_positions,
        )
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail=f"Input text is too long: {seq_len} tokens (maximum {max_positions}).",
        )

    with torch.no_grad():
        logits = model(**inputs).logits

    results = []
    for position in masked_positions.tolist():
        position_logits = logits[0, position]
        k = min(top_k, position_logits.numel())
        candidate_ids = torch.topk(position_logits, k).indices
        # One softmax over the vocabulary per position, then index it for the
        # candidates. Same scores as softmax(...)[id], without recomputing it.
        candidate_scores = torch.softmax(position_logits, dim=-1)[candidate_ids]
        for token_id, score in zip(candidate_ids.tolist(), candidate_scores.tolist()):
            filled_ids = input_ids[0].clone()
            filled_ids[position] = token_id
            results.append(
                PredictionResult(
                    # NOTE: skip_special_tokens presumably also drops any *other*
                    # [MASK] tokens, since [MASK] is a special token.
                    sequence=tokenizer.decode(filled_ids, skip_special_tokens=True),
                    score=score,
                    token=token_id,
                    token_str=tokenizer.decode(token_id),
                )
            )
    return results


async def predict_masked_lm(request: InferenceRequest):
    """Fill every ``[MASK]`` token in ``request.text`` with the model's top-5 guesses.

    NOTE(review): no route decorator is visible in this copy of the file, so as
    written this coroutine is not registered on ``api_router`` or ``app``.
    Presumably it was declared with something like ``@api_router.post(...)``;
    restore the original decorator.

    Args:
        request: Body whose ``text`` must contain at least one ``[MASK]`` token.

    Returns:
        A flat list of ``PredictionResult``: five candidates per mask position,
        in mask order, each group sorted by descending score.

    Raises:
        HTTPException: 400 for a missing ``[MASK]`` or over-long input; 422 on
            a pydantic ``ValidationError``; 500 on any other failure.
    """
    try:
        text = request.text
        # %r quotes and escapes the text so embedded newlines cannot forge log lines.
        logger.info("Received prediction request for text: %r", text)
        # Tokenization and inference are blocking CPU work. Running them in the
        # threadpool keeps the event loop free to serve other requests.
        results = await run_in_threadpool(_predict_masked_tokens, text)
        logger.info("Successfully processed request. Returning %d predictions.", len(results))
        return results
    except ValidationError as e:
        logger.error("Validation error for request: %s", e.errors())
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors(),
        ) from e
    except HTTPException:
        # Deliberate client errors (e.g. the 400s above) pass through unchanged.
        raise
    except Exception as e:
        # Full details and the traceback go to the server log only. Echoing the
        # exception text to clients would leak internal information.
        logger.exception("An unexpected error occurred during prediction: %s", e)
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail="An internal server error occurred.",
        ) from e
async def health_check():
    """Log the access and report that the API process is up.

    NOTE(review): no route decorator is visible in this copy; presumably this
    was registered with something like ``@app.get("/")``. Confirm upstream.
    """
    payload = {"message": "NeuroBERT-Tiny API is running!"}
    logger.info("Health check endpoint accessed.")
    return payload
# Attach the routes registered on ``api_router`` to the application.
# NOTE(review): FastAPI copies the router's routes at include time, so any
# route added to ``api_router`` after this line would not be served.
app.include_router(api_router)
async def catch_all(request: Request, path_name: str):
    """Log a request for an unrecognised path and answer it with a 404.

    NOTE(review): no route decorator is visible in this copy; the ``path_name``
    parameter suggests a catch-all path route. Confirm the original decorator.
    """
    not_found = HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
    logger.warning("Unhandled route accessed: %s %s", request.method, path_name)
    raise not_found
if __name__ == "__main__":
    # Direct execution: serve ``app`` with uvicorn on all interfaces, port 7860.
    import uvicorn

    server_options = {"host": "0.0.0.0", "port": 7860, "log_level": "info"}
    uvicorn.run(app, **server_options)