from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch

app = FastAPI()

# Load model globally to avoid reloading on each request
tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
model.eval() # Set model to evaluation mode
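# Note (assumption): NeuroBERT-Tiny is small enough to serve on CPU; for GPU
# serving you would also move the model with model.to("cuda") and send the
# tokenized inputs to the same device inside the handler.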

class InferenceRequest(BaseModel):
    text: str

class PredictionResult(BaseModel):
    sequence: str
    score: float
    token: int
    token_str: str
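
# Illustrative shape of one response item (values below are made up):
# POST {"text": "The capital of France is [MASK]."} might return items like
# {"sequence": "the capital of france is paris.", "score": 0.42,
#  "token": 3000, "token_str": "paris"}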

@app.post("/predict", response_model=list[PredictionResult])
async def predict_masked_lm(request: InferenceRequest):
    inputs = tokenizer(request.text, return_tensors="pt")

    # Find all [MASK] positions up front so we can reject inputs without one
    masked_token_indices = torch.where(inputs["input_ids"] == tokenizer.mask_token_id)[1]
    if masked_token_indices.numel() == 0:
        raise HTTPException(status_code=400, detail="Input text must contain at least one [MASK] token.")

    with torch.no_grad():
        outputs = model(**inputs)

    logits = outputs.logits

    results = []
    for masked_index in masked_token_indices:
        # Convert the logits at this position to probabilities once,
        # then take the top 5 candidate tokens
        probs = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
        top_5 = torch.topk(probs, 5)

        for score, predicted_token_id in zip(top_5.values.tolist(), top_5.indices.tolist()):
            predicted_token_str = tokenizer.decode([predicted_token_id])

            # Replace the [MASK] with the predicted token to render the full sequence
            temp_input_ids = inputs["input_ids"].clone()
            temp_input_ids[0, masked_index] = predicted_token_id
            full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)

            results.append(PredictionResult(
                sequence=full_sequence,
                score=score,
                token=predicted_token_id,
                token_str=predicted_token_str
            ))
    return results

# Optional: A simple health check endpoint
@app.get("/")
async def root():
    return {"message": "NeuroBERT-Tiny API is running!"}

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=8000)
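
# A minimal client sketch (assumes the server above is running locally on
# port 8000 and that the third-party `requests` package is installed):
#
#   import requests
#
#   resp = requests.post(
#       "http://localhost:8000/predict",
#       json={"text": "Turn [MASK] the light."},
#   )
#   resp.raise_for_status()
#   for pred in resp.json():
#       print(f"{pred['token_str']!r}: {pred['score']:.4f}")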