from fastapi import FastAPI, HTTPException, status, APIRouter, Request
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="NeuroBERT-Tiny Masked Language Model API",
    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
    version="1.0.0",
)

api_router = APIRouter()

try:
    logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
    tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
    model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
    model.eval()  # inference mode: disables dropout
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception("Failed to load model or tokenizer during startup!")
    raise RuntimeError(f"Could not load model: {e}")
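
# Optional tweak (a sketch, not part of the original app): NeuroBERT-Tiny is
# small enough to serve from CPU, but if a GPU is available the model could be
# moved to it; the tokenized inputs in the endpoint below would then need a
# matching .to(device) call as well.
# device = "cuda" if torch.cuda.is_available() else "cpu"
# model.to(device)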

class InferenceRequest(BaseModel):
    text: str


class PredictionResult(BaseModel):
    sequence: str
    score: float
    token: int
    token_str: str
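
# Shape of a successful /predict response, with illustrative values only
# (actual tokens and scores depend on the model):
# [
#   {
#     "sequence": "the capital of france is paris.",
#     "score": 0.42,
#     "token": 12345,
#     "token_str": "paris"
#   },
#   ...
# ]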

@api_router.post(
    "/predict",  # IMPORTANT: Prediction endpoint is now /predict
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text",
    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position.",
)
async def predict_masked_lm(request: InferenceRequest):
    try:
        text = request.text
        logger.info(f"Received prediction request for text: '{text}'")
        inputs = tokenizer(text, return_tensors="pt")
        with torch.no_grad():
            outputs = model(**inputs)
        logits = outputs.logits
        masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
        masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
        if not masked_token_indices.numel():
            logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Input text must contain at least one '[MASK]' token.",
            )
        results = []
        for masked_index in masked_token_indices:
            # Softmax over the vocabulary once per masked position; top-k on
            # the probabilities selects the same tokens as top-k on the logits.
            probs = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
            top_5 = torch.topk(probs, 5)
            for score, predicted_token_id in zip(top_5.values.tolist(), top_5.indices.tolist()):
                predicted_token_str = tokenizer.decode(predicted_token_id)
                # Splice the predicted token into a copy of the input to
                # render the full predicted sequence.
                temp_input_ids = inputs["input_ids"].clone()
                temp_input_ids[0, masked_index] = predicted_token_id
                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
                results.append(PredictionResult(
                    sequence=full_sequence,
                    score=score,
                    token=predicted_token_id,
                    token_str=predicted_token_str,
                ))
        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
        return results
    except ValidationError as e:
        logger.error(f"Validation error for request: {e.errors()}")
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors(),
        )
    except HTTPException:
        # Re-raise HTTPExceptions (e.g. the 400 above) without wrapping them.
        raise
    except Exception as e:
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}",
        )
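
# Example request against /predict (assuming the server is running on
# localhost:7860 as configured at the bottom of this file):
#
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"text": "The capital of France is [MASK]."}'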

@api_router.get(
    "/health",  # IMPORTANT: Health check endpoint is /health
    summary="Health Check",
    description="Returns a simple message indicating the API is running.",
)
async def health_check():
    logger.info("Health check endpoint accessed.")
    return {"message": "NeuroBERT-Tiny API is running!"}
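
# Example:
#   curl http://localhost:7860/health
#   -> {"message": "NeuroBERT-Tiny API is running!"}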
app.include_router(api_router)

@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def catch_all(request: Request, path_name: str):
    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")

if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
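
# A short client sketch (hypothetical, using the third-party requests library)
# for querying the running API:
#
#   import requests
#   resp = requests.post(
#       "http://localhost:7860/predict",
#       json={"text": "I want to [MASK] a new phone."},
#   )
#   resp.raise_for_status()
#   for prediction in resp.json():
#       print(prediction["token_str"], prediction["score"])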