from fastapi import FastAPI, HTTPException, status, APIRouter, Request
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import logging
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
title="NeuroBERT-Tiny Masked Language Model API",
description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
version="1.0.0"
)
api_router = APIRouter()
# --- NeuroBERT-Tiny Model Configuration ---
# Using boltuix/NeuroBERT-Tiny for Masked Language Modeling.
MODEL_NAME = "boltuix/NeuroBERT-Tiny"
# ----------------------------------------
# Load model globally to avoid reloading on each request
# This block runs once when the FastAPI application starts.
try:
logger.info(f"Loading tokenizer and model for {MODEL_NAME}...")
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
model.eval() # Set model to evaluation mode
logger.info("Model loaded successfully.")
except Exception as e:
logger.exception(f"Failed to load model or tokenizer for {MODEL_NAME} during startup!")
raise RuntimeError(f"Could not load model: {e}")
class InferenceRequest(BaseModel):
"""
Request model for the /predict endpoint.
Expects a single string field 'text' containing the sentence with [MASK] tokens.
"""
text: str
class PredictionResult(BaseModel):
"""
Response model for individual predictions from the /predict endpoint.
"""
sequence: str # The full sequence with the predicted token filled in
score: float # Confidence score of the prediction
token: int # The ID of the predicted token
token_str: str # The string representation of the predicted token
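# Illustrative shape of a single /predict response element for an input like
# "The capital of France is [MASK]." (values here are hypothetical, not actual
# model output):
#   {
#     "sequence": "the capital of france is paris.",
#     "score": 0.42,
#     "token": 3000,
#     "token_str": "paris"
#   }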
@api_router.post(
"/predict", # Prediction endpoint
response_model=list[PredictionResult],
summary="Predicts masked tokens in a given text using NeuroBERT-Tiny",
description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
)
async def predict_masked_lm(request: InferenceRequest):
"""
Predicts the most likely tokens for [MASK] positions in the input text using the NeuroBERT-Tiny model.
Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details.
"""
try:
text = request.text
logger.info(f"Received prediction request for text: '{text}'")
inputs = tokenizer(text, return_tensors="pt")
with torch.no_grad():
outputs = model(**inputs)
logits = outputs.logits
masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
# Find all masked tokens
masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
if not masked_token_indices.numel():
logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="Input text must contain at least one '[MASK]' token."
)
results = []
for masked_index in masked_token_indices:
# Get top 5 predictions for the masked token
top_5_logits = torch.topk(logits[0, masked_index], 5).values
top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
for i in range(5):
score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
predicted_token_id = top_5_tokens[i].item()
predicted_token_str = tokenizer.decode(predicted_token_id)
# Replace the [MASK] with the predicted token for the full sequence
temp_input_ids = inputs["input_ids"].clone()
temp_input_ids[0, masked_index] = predicted_token_id
full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
results.append(PredictionResult(
sequence=full_sequence,
score=score,
token=predicted_token_id,
token_str=predicted_token_str
))
logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
return results
except ValidationError as e:
logger.error(f"Validation error for request: {e.errors()}")
raise HTTPException(
status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
detail=e.errors()
)
except HTTPException:
raise # Re-raise custom HTTPExceptions
except Exception as e:
logger.exception(f"An unexpected error occurred during prediction: {e}")
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"An internal server error occurred: {e}"
)
@api_router.get(
"/health", # Health check endpoint
summary="Health Check",
description="Returns a simple message indicating the API is running."
)
async def health_check():
logger.info("Health check endpoint accessed.")
return {"message": "NeuroBERT-Tiny API is running!"}
app.include_router(api_router)
@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def catch_all(request: Request, path_name: str):
logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
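
# Example usage once the server is running (the request shape follows
# InferenceRequest above; the exact predictions depend on the model):
#   curl -X POST http://localhost:7860/predict \
#     -H "Content-Type: application/json" \
#     -d '{"text": "The capital of France is [MASK]."}'
#   curl http://localhost:7860/health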