# faq/app.py
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
app = FastAPI()
# Load model globally to avoid reloading on each request
tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
model.eval() # Set model to evaluation mode
class InferenceRequest(BaseModel):
    text: str


class PredictionResult(BaseModel):
    sequence: str
    score: float
    token: int
    token_str: str
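# Illustrative request/response shapes for /predict (the field values below are made
# up; the endpoint returns the top 5 predictions for each [MASK] token in the input):
#   request:  {"text": "The capital of France is [MASK]."}
#   response: [{"sequence": "...", "score": 0.42, "token": 12345, "token_str": "paris"}, ...]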
@app.post("/predict", response_model=list[PredictionResult])
async def predict_masked_lm(request: InferenceRequest):
    text = request.text
    inputs = tokenizer(text, return_tensors="pt")
    with torch.no_grad():
        outputs = model(**inputs)
        logits = outputs.logits

    # Find the positions of all [MASK] tokens in the input
    masked_token_indices = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if masked_token_indices.numel() == 0:
        raise HTTPException(status_code=400, detail="Input text must contain at least one [MASK] token.")

    results = []
    for masked_index in masked_token_indices:
        # Convert the logits at the masked position to probabilities and take the top 5 predictions
        probabilities = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
        top_5 = torch.topk(probabilities, 5)
        for score, predicted_token_id in zip(top_5.values.tolist(), top_5.indices.tolist()):
            predicted_token_str = tokenizer.decode(predicted_token_id)
            # Replace the [MASK] with the predicted token to build the full predicted sequence
            temp_input_ids = inputs["input_ids"].clone()
            temp_input_ids[0, masked_index] = predicted_token_id
            full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)
            results.append(PredictionResult(
                sequence=full_sequence,
                score=score,
                token=predicted_token_id,
                token_str=predicted_token_str,
            ))
    return results
# Optional: a simple health check endpoint
@app.get("/")
async def root():
    return {"message": "NeuroBERT-Tiny API is running!"}


if __name__ == "__main__":
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)
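# Minimal client sketch (assumes the server is running locally on port 8000 and that
# the `requests` package is installed; the example text is illustrative):
#
#   import requests
#   response = requests.post(
#       "http://localhost:8000/predict",
#       json={"text": "The weather today is [MASK]."},
#   )
#   for prediction in response.json():
#       print(prediction["token_str"], prediction["score"])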