from fastapi import FastAPI, HTTPException, status
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForMaskedLM
import torch
import logging

# Configure logging to output information, warnings, and errors
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="NeuroBERT-Tiny Masked Language Model API",
    description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
    version="1.0.0"
)
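
# Note: once running, FastAPI serves interactive API documentation for this app
# at /docs (Swagger UI) and /redoc by default.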

# Load model globally to avoid reloading on each request
# This block runs once when the FastAPI application starts.
try:
    logger.info("Loading tokenizer and model for boltuix/NeuroBERT-Tiny...")
    tokenizer = AutoTokenizer.from_pretrained("boltuix/NeuroBERT-Tiny")
    model = AutoModelForMaskedLM.from_pretrained("boltuix/NeuroBERT-Tiny")
    model.eval() # Set model to evaluation mode for inference
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception("Failed to load model or tokenizer during startup!")
    # Fail fast: re-raise so the application does not start without a usable model.
    # If you prefer the app to start anyway and fail on the first prediction,
    # remove this raise.
    raise RuntimeError(f"Could not load model: {e}")

class InferenceRequest(BaseModel):
    """
    Request model for the /predict endpoint.
    Expects a single string field 'text' containing the sentence with [MASK] tokens.
    """
    text: str

class PredictionResult(BaseModel):
    """
    Response model for individual predictions from the /predict endpoint.
    """
    sequence: str    # The full sequence with the predicted token filled in
    score: float     # Confidence score of the prediction
    token: int       # The ID of the predicted token
    token_str: str   # The string representation of the predicted token
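
    # Illustrative JSON serialization of one prediction (values are made up):
    #   {"sequence": "the sky is blue", "score": 0.42, "token": 2630, "token_str": "blue"}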

@app.post(
    "/predict",
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text",
    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
)
async def predict_masked_lm(request: InferenceRequest):
    """
    Predicts the most likely tokens for [MASK] positions in the input text.
    Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details.
    """
    try:
        text = request.text
        logger.info(f"Received prediction request for text: '{text}'")

        # Tokenize the input text
        inputs = tokenizer(text, return_tensors="pt")

        # Perform inference without tracking gradients
        with torch.no_grad():
            outputs = model(**inputs)

        logits = outputs.logits
        masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")

        # Find all masked token positions in the input IDs
        masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]

        if not masked_token_indices.numel():
            logger.warning("No [MASK] token found in the input text.")
            raise HTTPException(
                status_code=status.HTTP_400_BAD_REQUEST,
                detail="Input text must contain at least one '[MASK]' token."
            )

        results = []
        # Iterate over each masked token found in the input
        for masked_index in masked_token_indices:
            # Convert logits to probabilities once for this masked position,
            # then take the top 5 candidate tokens and their scores.
            probabilities = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)
            top_5 = torch.topk(probabilities, 5)

            # For each of the top 5 predictions
            for score_tensor, token_tensor in zip(top_5.values, top_5.indices):
                score = score_tensor.item()
                predicted_token_id = token_tensor.item()
                predicted_token_str = tokenizer.decode(predicted_token_id)
                
                # Create a temporary input_ids tensor to replace the [MASK] token
                # with the current predicted token for generating the full sequence.
                temp_input_ids = inputs["input_ids"].clone()
                temp_input_ids[0, masked_index] = predicted_token_id
                
                # Decode the entire sequence, skipping special tokens, to get the complete predicted sentence.
                full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)

                # Append the prediction result to our list
                results.append(PredictionResult(
                    sequence=full_sequence,
                    score=score,
                    token=predicted_token_id,
                    token_str=predicted_token_str
                ))
        
        logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
        return results
    
    except ValidationError as e:
        logger.error(f"Validation error for request: {e.errors()}")
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors()
        )
    except HTTPException:
        # Re-raise explicit HTTPExceptions (e.g., 400 for missing [MASK])
        raise
    except Exception as e:
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}"
        )

@app.get(
    "/",
    summary="Health Check",
    description="Returns a simple message indicating the API is running."
)
async def root():
    """
    Provides a basic health check endpoint for the API.
    """
    logger.info("Health check endpoint accessed.")
    return {"message": "NeuroBERT-Tiny API is running!"}

# This block is for running the app directly, typically used for local development.
# In a Docker container, Uvicorn (or Gunicorn) is usually invoked via the CMD in Dockerfile.
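# An illustrative Dockerfile CMD (assumes this file is saved as main.py; adjust as needed):
#   CMD ["uvicorn", "main:app", "--host", "0.0.0.0", "--port", "8000"]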
if __name__ == "__main__":
    import uvicorn
    # Passing reload=True is handy for local development (auto-reloads on code changes),
    # but it is omitted here since reload is not appropriate for production or containers.
    uvicorn.run(app, host="0.0.0.0", port=8000, log_level="info")