File size: 5,790 Bytes
de7dbb0
 
f97c475
de7dbb0
 
 
 
 
 
 
f97c475
 
de7dbb0
 
 
 
 
f97c475
 
 
 
665d3e9
f97c475
665d3e9
de7dbb0
665d3e9
 
f97c475
665d3e9
de7dbb0
 
665d3e9
de7dbb0
 
 
665d3e9
 
 
 
de7dbb0
 
 
665d3e9
 
 
 
f97c475
 
 
de7dbb0
 
f97c475
de7dbb0
f97c475
 
de7dbb0
 
665d3e9
f97c475
 
665d3e9
f97c475
 
 
665d3e9
f97c475
665d3e9
f97c475
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1d0ba1f
de7dbb0
f97c475
 
 
 
 
 
 
 
 
 
 
 
 
 
de7dbb0
 
 
f97c475
 
 
de7dbb0
119dbbc
f97c475
665d3e9
 
 
cc2745c
 
 
 
 
 
665d3e9
cc2745c
 
 
 
 
 
1d0ba1f
de7dbb0
 
 
 
 
 
 
f97c475
de7dbb0
 
 
 
 
 
 
 
 
 
f97c475
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
import asyncio
import logging

import torch
from fastapi import FastAPI, HTTPException, status, APIRouter, Request
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForMaskedLM

# Configure the root logger at INFO; this module logs through its own named logger.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# The ASGI application. title/description/version are the metadata FastAPI
# publishes in the generated OpenAPI schema.
app = FastAPI(
    version="1.0.0",
    title="NeuroBERT-Tiny Masked Language Model API",
    description=(
        "An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model."
    ),
)

# Endpoints are declared on this router; it is attached to `app` further down.
api_router = APIRouter()

# --- NeuroBERT-Tiny Model Configuration ---
# Model identifier handed to `from_pretrained` for both tokenizer and model.
MODEL_NAME = "boltuix/NeuroBERT-Tiny"
# ----------------------------------------

# Load the tokenizer and model once, when this module is imported, so every
# request reuses the same instances instead of reloading them.
try:
    logger.info("Loading tokenizer and model for %s...", MODEL_NAME)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
    model.eval()  # inference only: switch off training-mode layers such as dropout
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception("Failed to load model or tokenizer for %s during startup!", MODEL_NAME)
    # Abort the import: none of the endpoints can work without the model.
    # `from e` keeps the loader's error as the explicit cause in the traceback.
    raise RuntimeError(f"Could not load model: {e}") from e

class InferenceRequest(BaseModel):
    """
    JSON body accepted by the POST /predict endpoint.

    The only field, ``text``, is the sentence to complete. It is expected to
    contain one or more ``[MASK]`` placeholders for the model to fill in.
    """

    # Raw input sentence, e.g. "Paris is the [MASK] of France."
    text: str

class PredictionResult(BaseModel):
    """
    A single candidate fill for one [MASK] position.

    The /predict endpoint returns a list of these objects.
    """

    # Input text decoded back to a string with this candidate in place of the mask.
    sequence: str
    # Softmax probability the model assigns to this candidate.
    score: float
    # Vocabulary id of the candidate token.
    token: int
    # Decoded text of the candidate token on its own.
    token_str: str

# Number of candidate tokens returned for each masked position.
_TOP_K = 5


def _run_masked_lm(text: str) -> list[PredictionResult]:
    """Run the masked-LM on ``text`` and return the top candidates for every mask.

    This is the synchronous, compute-bound part of ``/predict``. The endpoint
    runs it in a worker thread so the event loop is not blocked.

    Args:
        text: Input sentence containing one or more mask tokens.

    Returns:
        For each mask position, left to right, up to ``_TOP_K`` predictions
        ordered from most to least likely.

    Raises:
        HTTPException: 400 if ``text`` contains no mask token.
    """
    inputs = tokenizer(text, return_tensors="pt")
    # A single unbatched input: only batch row 0 exists.
    input_ids = inputs["input_ids"][0]

    # Prefer the tokenizer's declared mask id. convert_tokens_to_ids("[MASK]")
    # returns the unknown-token id when the vocabulary has no "[MASK]" entry,
    # which would make every unknown word count as a mask.
    mask_id = tokenizer.mask_token_id
    if mask_id is None:
        mask_id = tokenizer.convert_tokens_to_ids("[MASK]")

    mask_positions = (input_ids == mask_id).nonzero(as_tuple=True)[0].tolist()
    # Reject before running the model so invalid requests cost no inference.
    if not mask_positions:
        logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Input text must contain at least one '[MASK]' token."
        )

    with torch.no_grad():
        logits = model(**inputs).logits[0]

    # Strip [CLS]/[SEP]/padding and similar tokens from the returned sequences,
    # but keep the *other* mask tokens. skip_special_tokens=True would silently
    # delete them and produce a garbled sentence whenever several masks are present.
    hidden_ids = set(tokenizer.all_special_ids) - {mask_id}
    base_ids = input_ids.tolist()
    k = min(_TOP_K, logits.size(-1))

    results = []
    for position in mask_positions:
        # Compute the softmax once per position. topk over the probabilities
        # selects the same tokens as topk over the logits and also yields the scores.
        top_probs, top_ids = torch.softmax(logits[position], dim=-1).topk(k)
        for score, token_id in zip(top_probs.tolist(), top_ids.tolist()):
            filled = base_ids.copy()
            filled[position] = token_id
            results.append(PredictionResult(
                sequence=tokenizer.decode([t for t in filled if t not in hidden_ids]),
                score=score,
                token=token_id,
                token_str=tokenizer.decode(token_id),
            ))
    return results


@api_router.post(
    "/predict", # Prediction endpoint
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text using NeuroBERT-Tiny",
    description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
)
async def predict_masked_lm(request: InferenceRequest):
    """
    Predicts the most likely tokens for [MASK] positions in the input text using the NeuroBERT-Tiny model.
    Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details.

    Raises:
        HTTPException: 400 when the text contains no mask token; 500 on any
            other failure during inference or response construction.
    """
    text = request.text
    logger.info("Received prediction request for text: '%s'", text)
    try:
        # Tokenization and the forward pass are blocking calls. Awaiting them in
        # a worker thread keeps the event loop free to serve other requests
        # (e.g. /health) while inference runs.
        results = await asyncio.to_thread(_run_masked_lm, text)
    except HTTPException:
        raise  # deliberate client errors (the 400 above) pass through unchanged
    except Exception as e:
        # Everything else is a server-side failure. This includes a pydantic
        # ValidationError: FastAPI validates the request body before this handler
        # runs, so one raised here comes from building the response and is
        # reported as 500, not 422.
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}"
        ) from e

    logger.info("Successfully processed request. Returning %d predictions.", len(results))
    return results

@api_router.get(
    "/health",  # Health check endpoint
    description="Returns a simple message indicating the API is running.",
    summary="Health Check",
)
async def health_check():
    """Log the call and return a fixed message confirming the service is up."""
    payload = {"message": "NeuroBERT-Tiny API is running!"}
    logger.info("Health check endpoint accessed.")
    return payload

# Attach /predict and /health before the catch-all route below is registered,
# so these routes are matched first.
app.include_router(router=api_router)

@app.api_route(
    "/{path_name:path}",
    methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"],
    # A fallback, not a real endpoint: keep it out of the OpenAPI schema, where
    # it would otherwise be advertised under seven HTTP methods.
    include_in_schema=False,
)
async def catch_all(request: Request, path_name: str):
    """Log, then reject with 404, any request no earlier-registered route fully matched.

    NOTE: because this route matches every path and method, a known path called
    with the wrong method (e.g. GET /predict) ends up here and gets 404 rather
    than 405.

    Raises:
        HTTPException: always, with status 404.
    """
    logger.warning("Unhandled route accessed: %s %s", request.method, path_name)
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")

if __name__ == "__main__":
    # Direct-execution entry point: serve `app` with uvicorn on all interfaces.
    import uvicorn

    uvicorn.run(
        app,
        log_level="info",
        host="0.0.0.0",
        port=7860,
    )