from fastapi import FastAPI, HTTPException, status, APIRouter, Request from pydantic import BaseModel, ValidationError from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline import torch import logging import asyncio # For running synchronous model inference in a separate thread logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) app = FastAPI( title="Masked Language Model API (via TinyLlama)", description="An API to perform Masked Language Modeling using a locally hosted TinyLlama model.", version="1.0.0" ) api_router = APIRouter() # --- TinyLlama Model Configuration --- # Using TinyLlama-1.1B-Chat-v1.0 which is a small, Llama-like model suitable for local inference. MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0" # ----------------------------------- # Load model and tokenizer globally to avoid reloading on each request # This block runs once when the FastAPI application starts. try: logger.info(f"Loading tokenizer and model for {MODEL_NAME}...") # Load tokenizer and model for Causal LM (text generation) tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME) # Using torch_dtype=torch.bfloat16 for potential memory/speed benefits if GPU is available # and to fit within common memory limits. Also using device_map="auto" to load efficiently. model = AutoModelForCausalLM.from_pretrained( MODEL_NAME, torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32, device_map="auto" ) model.eval() # Set model to evaluation mode # Create a text generation pipeline # We will adjust this pipeline's behavior in predict_masked_lm # to simulate masked LM functionality by prompting the LLM. text_generator = pipeline( "text-generation", model=model, tokenizer=tokenizer, # Ensure pad_token_id is set if tokenizer does not have one, to avoid warnings/errors pad_token_id=tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id ) logger.info("Model loaded successfully.") except Exception as e: logger.exception(f"Failed to load model or tokenizer for {MODEL_NAME} during startup!") raise RuntimeError(f"Could not load model: {e}") class InferenceRequest(BaseModel): """ Request model for the /predict endpoint. Expects a single string field 'text' containing the sentence with [MASK] tokens. """ text: str class PredictionResult(BaseModel): """ Response model for individual predictions from the /predict endpoint. Simplified to focus on the sequence and score, abstracting token details. """ sequence: str # The full sequence with the predicted token filled in score: float # Confidence score of the prediction (approximated for generative LLMs) async def run_inference_blocking(generator_pipeline, prompt, num_return_sequences=5): """ Runs the synchronous model inference in a separate thread to avoid blocking FastAPI's event loop. """ return generator_pipeline( prompt, max_new_tokens=10, # Generate a small number of tokens for the mask num_return_sequences=num_return_sequences, do_sample=True, # Enable sampling for varied predictions temperature=0.7, # Control randomness top_k=50, # Consider top K tokens for sampling top_p=0.95, # Consider tokens up to a certain cumulative probability # The stop_sequence ensures it doesn't generate too much beyond the expected word stop_sequence=[" ", ".", ",", "!", "?", "\n"] # Stop after generating a word/punctuation ) @api_router.post( "/predict", # Prediction endpoint remains /predict response_model=list[PredictionResult], summary="Predicts masked tokens in a given text using a local TinyLlama model", description="Accepts a text string with '[MASK]' tokens and returns up to 5 single-word predictions for each masked position using a local generative AI model. Output is simplified to sequences and scores." ) async def predict_masked_lm(request: InferenceRequest): """ Predicts the most likely tokens for [MASK] positions in the input text using the TinyLlama model. Returns a list of top predictions for each masked token, including the full sequence and an approximated score. """ text = request.text logger.info(f"Received prediction request for text: '{text}'") if "[MASK]" not in text: logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.") raise HTTPException( status_code=status.HTTP_400_BAD_REQUEST, detail="Input text must contain at least one '[MASK]' token." ) # Find the position of the first [MASK] token to correctly prompt the LLM # And to insert predictions back into the original text mask_start_index = text.find("[MASK]") if mask_start_index == -1: # Should already be caught above, but as a safeguard raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No '[MASK]' token found in input.") # Craft a prompt that encourages the LLM to fill the mask. # The prompt guides the generative LLM to act like a fill-mask model. # Example: "The quick brown fox jumps over the [MASK] dog. The word that should replace [MASK] is:" # We remove "[MASK]" from the prompt for the generative model, and then # prepend a guiding phrase and append the text after the mask. # Split text around the first [MASK] parts = text.split("[MASK]", 1) if len(parts) < 2: # Should not happen if [MASK] is found raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error processing mask position.") pre_mask_text = parts[0].strip() post_mask_text = parts[1].strip() # Construct the prompt to guide TinyLlama # "Fill in the blank: 'The quick brown fox jumps over the ______ dog.' Best options are:" prompt = f"Complete the missing word in the following sentence. Give 5 single-word options. Sentence: '{pre_mask_text} ____ {post_mask_text}' Options:" try: # Run inference in a separate thread to not block the main event loop # The model's output will be a list of dicts, e.g., [{"generated_text": "Prompt + predicted word"}] raw_predictions = await run_inference_blocking(text_generator, prompt) results = [] seen_words = set() # To ensure unique predictions for i, pred_item in enumerate(raw_predictions): generated_text = pred_item.get("generated_text", "") # Extract only the predicted word from the generated text # This is heuristic and might need fine-tuning based on actual model output # We look for text that comes *after* our prompt and try to extract the first word. if prompt in generated_text: completion_text = generated_text.split(prompt, 1)[-1].strip() # Try to extract the first word if it contains spaces predicted_word = completion_text.split(' ', 1)[0].strip().replace('.', '').replace(',', '') # Filter out numbers, common filler words, or very short non-alpha words if not predicted_word.isalpha() or len(predicted_word) < 2: continue # Further refine by splitting on common word separators, taking the first valid word valid_words = [w for w in predicted_word.split() if w.isalpha() and len(w) > 1] if not valid_words: continue predicted_word = valid_words[0].lower() # Normalize to lowercase # Ensure unique predictions if predicted_word in seen_words: continue seen_words.add(predicted_word) # Construct the full sequence with the predicted word full_sequence = text.replace("[MASK]", predicted_word, 1) # Approximate score (generative LLMs don't give scores directly for words) mock_score = 0.95 - (i * 0.01) # Slightly decrease confidence for lower ranks results.append(PredictionResult( sequence=full_sequence, score=mock_score )) if len(results) >= 5: # Stop after getting 5 valid results break if not results: logger.warning("No valid predictions could be formatted from LLM response.") raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not extract predictions from TinyLlama output.") logger.info(f"Successfully processed request via TinyLlama. Returning {len(results)} predictions.") return results except ValidationError as e: logger.error(f"Validation error for request: {e.errors()}") raise HTTPException( status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, detail=e.errors() ) except HTTPException: raise # Re-raise custom HTTPExceptions except Exception as e: logger.exception(f"An unexpected error occurred during prediction: {e}") raise HTTPException( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"An internal server error occurred: {e}" ) @api_router.get( "/health", # Health check endpoint summary="Health Check", description="Returns a simple message indicating the API is running." ) async def health_check(): logger.info("Health check endpoint accessed.") return {"message": "Masked Language Model API (via TinyLlama) is running!"} app.include_router(api_router) @app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"]) async def catch_all(request: Request, path_name: str): logger.warning(f"Unhandled route accessed: {request.method} {path_name}") raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found") if __name__ == "__main__": import uvicorn uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")