# FastAPI service exposing boltuix/NeuroBERT-Tiny for masked-token prediction
from fastapi import FastAPI
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

app = FastAPI()

# Load the tokenizer and model globally to avoid reloading them on each request
tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
model.eval()  # Set the model to evaluation mode (disables dropout)

class InferenceRequest(BaseModel):
    text: str

class PredictionResult(BaseModel):
    sequence: str
    score: float
    token: int
    token_str: str

@app.post("/predict", response_model=list[PredictionResult])  # route path "/predict" is assumed
async def predict_masked_lm(request: InferenceRequest):
    text = request.text
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
    logits = outputs.logits

    masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
    # Find the positions of all [MASK] tokens in the input
    masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]

    results = []
    for masked_index in masked_token_indices:
        # Take the top 5 candidate tokens for this masked position
        probs = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
        top_5 = torch.topk(probs, 5)
        for i in range(5):
            score = top_5.values[i].item()
            predicted_token_id = top_5.indices[i].item()
            predicted_token_str = tokenizer.decode(predicted_token_id)
            # Replace the [MASK] with the predicted token to produce the full sequence
            temp_input_ids = inputs["input_ids"].clone()
            temp_input_ids[0, masked_index] = predicted_token_id
            full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
            results.append(PredictionResult(
                sequence=full_sequence,
                score=score,
                token=predicted_token_id,
                token_str=predicted_token_str
            ))
    return results

# Optional: a simple health check endpoint
@app.get("/")
async def root():
    return {"message": "NeuroBERT-Tiny API is running!"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
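
With the server running (for example by executing the file directly, which starts uvicorn on port 8000 as shown above), the endpoint can be exercised with a small client. The sketch below is illustrative only: it assumes the route is the /predict path used above, that the server is reachable on localhost:8000, and that the requests package is installed; the client_example.py name is hypothetical.

# client_example.py - hypothetical client for the service above
import requests

payload = {"text": "The capital of France is [MASK]."}
response = requests.post("http://localhost:8000/predict", json=payload)
response.raise_for_status()

# Each element mirrors the PredictionResult model: sequence, score, token, token_str
for prediction in response.json():
    print(f"{prediction['token_str']!r}  score={prediction['score']:.4f}  ->  {prediction['sequence']}")

Because the tokenizer and model are loaded at import time, each request only pays for tokenization and a single forward pass; the model weights are never reloaded per call.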