from fastapi import FastAPI, HTTPException, status, APIRouter, Request
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging
import asyncio # For running synchronous model inference in a separate thread

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

app = FastAPI(
    title="Masked Language Model API (via TinyLlama)",
    description="An API to perform Masked Language Modeling using a locally hosted TinyLlama model.",
    version="1.0.0"
)

api_router = APIRouter()

# --- TinyLlama Model Configuration ---
# Using TinyLlama-1.1B-Chat-v1.0 which is a small, Llama-like model suitable for local inference.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# -----------------------------------

# Load model and tokenizer globally to avoid reloading on each request
# This block runs once when the FastAPI application starts.
try:
    logger.info(f"Loading tokenizer and model for {MODEL_NAME}...")
    # Load tokenizer and model for Causal LM (text generation)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Using torch_dtype=torch.bfloat16 for potential memory/speed benefits if GPU is available
    # and to fit within common memory limits. Also using device_map="auto" to load efficiently.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME, 
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, 
        device_map="auto"
    )
    model.eval() # Set model to evaluation mode
    
    # Create a text generation pipeline
    # We will adjust this pipeline's behavior in predict_masked_lm
    # to simulate masked LM functionality by prompting the LLM.
    text_generator = pipeline(
        "text-generation", 
        model=model, 
        tokenizer=tokenizer,
        # Ensure pad_token_id is set if tokenizer does not have one, to avoid warnings/errors
        pad_token_id=tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    )
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception(f"Failed to load model or tokenizer for {MODEL_NAME} during startup!")
    raise RuntimeError(f"Could not load model: {e}")

class InferenceRequest(BaseModel):
    """
    Request model for the /predict endpoint.
    Expects a single string field 'text' containing the sentence with [MASK] tokens.
    """
    text: str

class PredictionResult(BaseModel):
    """
    Response model for individual predictions from the /predict endpoint.
    Simplified to focus on the sequence and score, abstracting token details.
    """
    sequence: str    # The full sequence with the predicted token filled in
    score: float     # Confidence score of the prediction (approximated for generative LLMs)

async def run_inference_blocking(generator_pipeline, prompt, num_return_sequences=5):
    """
    Runs the synchronous model inference in a worker thread (via asyncio.to_thread)
    so it does not block FastAPI's event loop.
    """
    def _generate():
        return generator_pipeline(
            prompt,
            max_new_tokens=10,  # Generate only a few tokens for the masked word
            num_return_sequences=num_return_sequences,
            do_sample=True,     # Enable sampling for varied predictions
            temperature=0.7,    # Control randomness
            top_k=50,           # Consider only the top K tokens when sampling
            top_p=0.95,         # Nucleus sampling: cumulative-probability cutoff
            # max_new_tokens bounds the completion length; the pipeline's
            # stop_sequence parameter expects a single (single-token) string,
            # so it is not used here.
        )

    return await asyncio.to_thread(_generate)


@api_router.post(
    "/predict", # Prediction endpoint remains /predict
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text using a local TinyLlama model",
    description="Accepts a text string with '[MASK]' tokens and returns up to 5 single-word predictions for each masked position using a local generative AI model. Output is simplified to sequences and scores."
)
async def predict_masked_lm(request: InferenceRequest):
    """
    Predicts the most likely tokens for [MASK] positions in the input text using the TinyLlama model.
    Returns a list of top predictions for each masked token, including the full sequence and an approximated score.
    """
    text = request.text
    logger.info(f"Received prediction request for text: '{text}'")

    if "[MASK]" not in text:
        logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Input text must contain at least one '[MASK]' token."
        )
    
    # Split the input around the first [MASK] so the surrounding context can be folded
    # into a fill-in-the-blank prompt, and so predictions can later be inserted back
    # into the original text. The prompt guides the generative LLM to act like a
    # fill-mask model: "[MASK]" is removed, a guiding instruction is prepended, and
    # the text after the mask is appended.
    parts = text.split("[MASK]", 1)
    if len(parts) < 2:  # Should not happen, since [MASK] was checked above
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error processing mask position.")
    
    pre_mask_text = parts[0].strip()
    post_mask_text = parts[1].strip()

    # Construct the prompt that guides TinyLlama to fill in the blank, e.g.:
    # "Complete the missing word in the following sentence. Give 5 single-word options.
    #  Sentence: 'The quick brown fox jumps over the ____ dog.' Options:"
    prompt = f"Complete the missing word in the following sentence. Give 5 single-word options. Sentence: '{pre_mask_text} ____ {post_mask_text}' Options:"

    try:
        # Run inference in a separate thread to not block the main event loop
        # The model's output will be a list of dicts, e.g., [{"generated_text": "Prompt + predicted word"}]
        raw_predictions = await run_inference_blocking(text_generator, prompt)
        
        results = []
        seen_words = set() # To ensure unique predictions

        for i, pred_item in enumerate(raw_predictions):
            generated_text = pred_item.get("generated_text", "")
            
            # Extract only the predicted word from the generated text
            # This is heuristic and might need fine-tuning based on actual model output
            # We look for text that comes *after* our prompt and try to extract the first word.
            if prompt in generated_text:
                completion_text = generated_text.split(prompt, 1)[-1].strip()
                # Try to extract the first word if it contains spaces
                predicted_word = completion_text.split(' ', 1)[0].strip().replace('.', '').replace(',', '')
                # Filter out numbers, common filler words, or very short non-alpha words
                if not predicted_word.isalpha() or len(predicted_word) < 2:
                    continue
                
                # The word is already a single alphabetic token at this point;
                # normalize to lowercase for the uniqueness check below.
                predicted_word = predicted_word.lower()

                # Ensure unique predictions
                if predicted_word in seen_words:
                    continue
                seen_words.add(predicted_word)

                # Construct the full sequence with the predicted word
                full_sequence = text.replace("[MASK]", predicted_word, 1)
                
                # Approximate score (generative LLMs don't give scores directly for words)
                mock_score = 0.95 - (i * 0.01) # Slightly decrease confidence for lower ranks

                results.append(PredictionResult(
                    sequence=full_sequence,
                    score=mock_score
                ))
            
            if len(results) >= 5: # Stop after getting 5 valid results
                break
        
        if not results:
            logger.warning("No valid predictions could be formatted from LLM response.")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not extract predictions from TinyLlama output.")

        logger.info(f"Successfully processed request via TinyLlama. Returning {len(results)} predictions.")
        return results
    
    except ValidationError as e:
        logger.error(f"Validation error for request: {e.errors()}")
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors()
        )
    except HTTPException:
        raise # Re-raise custom HTTPExceptions
    except Exception as e:
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}"
        )

@api_router.get(
    "/health", # Health check endpoint
    summary="Health Check",
    description="Returns a simple message indicating the API is running."
)
async def health_check():
    logger.info("Health check endpoint accessed.")
    return {"message": "Masked Language Model API (via TinyLlama) is running!"}

app.include_router(api_router)

@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def catch_all(request: Request, path_name: str):
    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
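
# Example request (a sketch; assumes the server is running locally on port 7860 as
# configured above, and that curl is available):
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The quick brown fox jumps over the [MASK] dog."}'
#
# The response is a JSON list of up to 5 objects shaped like PredictionResult.
# Illustrative values only; actual words depend on the model, and scores are the
# approximated ranks assigned in predict_masked_lm:
#
#   [{"sequence": "The quick brown fox jumps over the lazy dog.", "score": 0.95}, ...]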