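"""FastAPI service exposing boltuix/NeuroBERT-Tiny for masked language modeling.

Endpoints:
    POST /predict -- top-5 predictions for every [MASK] token in the input text
    GET  /health  -- simple liveness check
"""
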
import logging

import torch
from fastapi import FastAPI, HTTPException, status, APIRouter, Request
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForMaskedLM

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="NeuroBERT-Tiny Masked Language Model API",
    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
    version="1.0.0"
)
api_router = APIRouter()

# Load the tokenizer and model once at startup so every request reuses the
# same weights. from_pretrained() downloads the checkpoint on first run and
# uses the local cache afterwards.
try:
    logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
    tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
    model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
    model.eval()  # inference mode: disables dropout
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception("Failed to load model or tokenizer during startup!")
    raise RuntimeError(f"Could not load model: {e}") from e

class InferenceRequest(BaseModel):
    text: str

class PredictionResult(BaseModel):
    sequence: str   # full input text with this [MASK] filled by the candidate
    score: float    # softmax probability of the candidate token
    token: int      # candidate token id in the tokenizer vocabulary
    token_str: str  # decoded candidate token
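
# Illustrative shape of one response item (values invented for example
# purposes, not real model output):
#   {"sequence": "the capital of france is paris", "score": 0.42,
#    "token": 3000, "token_str": "paris"}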

@api_router.post(
    "/predict",  # IMPORTANT: Prediction endpoint is now /predict
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text",
    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
)
async def predict_masked_lm(request: InferenceRequest):
    try:
        text = request.text
        logger.info(f"Received prediction request for text: '{text}'")

        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits

        # Locate every [MASK] position in the (single-sequence) batch.
        masked_token_id = tokenizer.mask_token_id
        masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
        if not masked_token_indices.numel():
            logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Input text must contain at least one '[MASK]' token."
            )
        results = []
        for masked_index in masked_token_indices:
            # Convert the logits at this position to probabilities once, then
            # take the five most likely candidate tokens. (Softmax is monotonic,
            # so top-k over probabilities matches top-k over logits.)
            probs = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
            top_5 = torch.topk(probs, 5)
            for score, predicted_token_id in zip(top_5.values.tolist(), top_5.indices.tolist()):
                predicted_token_str = tokenizer.decode(predicted_token_id)
                # Substitute the candidate into a copy of the input ids so the
                # full predicted sequence can be rendered.
                temp_input_ids = inputs["input_ids"].clone()
                temp_input_ids[0, masked_index] = predicted_token_id
                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
                results.append(PredictionResult(
                    sequence=full_sequence,
                    score=score,
                    token=predicted_token_id,
                    token_str=predicted_token_str
                ))

        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
        return results
    except ValidationError as e:
        # FastAPI normally validates the body before the handler runs; this is
        # kept as a safeguard for validation errors raised manually.
        logger.error(f"Validation error for request: {e.errors()}")
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors()
        )
    except HTTPException:
        # Re-raise HTTP errors (such as the 400 above) unchanged rather than
        # letting the generic handler below convert them into a 500.
        raise
    except Exception as e:
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}"
        )
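
# Example call (a sketch; assumes the server is running on the host/port
# configured at the bottom of this file):
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The capital of France is [MASK]."}'
# The response is a JSON list with five PredictionResult items per [MASK]
# token in the input.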

@api_router.get(
    "/health",  # IMPORTANT: Health check endpoint is /health
    summary="Health Check",
    description="Returns a simple message indicating the API is running."
)
async def health_check():
    logger.info("Health check endpoint accessed.")
    return {"message": "NeuroBERT-Tiny API is running!"}
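
# Quick liveness probe (same host/port assumption as the /predict example):
#   curl http://localhost:7860/health
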
app.include_router(api_router)

# Catch-all for any route not matched above: log the attempt and return 404.
@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def catch_all(request: Request, path_name: str):
    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
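
# Equivalent launch via the uvicorn CLI (assuming this file is saved as app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860 --log-level info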