from fastapi import FastAPI, HTTPException, status, APIRouter, Request
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging
import asyncio # For running synchronous model inference in a separate thread
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="Masked Language Model API (via TinyLlama)",
    description="An API to perform Masked Language Modeling using a locally hosted TinyLlama model.",
    version="1.0.0"
)
api_router = APIRouter()
# --- TinyLlama Model Configuration ---
# Using TinyLlama-1.1B-Chat-v1.0 which is a small, Llama-like model suitable for local inference.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# -----------------------------------
# Load model and tokenizer globally to avoid reloading on each request
# This block runs once when the FastAPI application starts.
try:
    logger.info(f"Loading tokenizer and model for {MODEL_NAME}...")
    # Load tokenizer and model for causal LM (text generation)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Use torch.bfloat16 when a GPU is available for memory/speed benefits; fall back to
    # float32 on CPU. device_map="auto" places the weights on available devices efficiently.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto"
    )
    model.eval()  # Set model to evaluation mode
    # Create a text-generation pipeline. predict_masked_lm prompts this generative
    # pipeline in a way that simulates masked-LM behaviour.
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # Fall back to the EOS token as pad token if none is set, to avoid warnings/errors
        pad_token_id=tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    )
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception(f"Failed to load model or tokenizer for {MODEL_NAME} during startup!")
    raise RuntimeError(f"Could not load model: {e}") from e
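# Note: a checkpoint trained for masked LM exposes this functionality directly via the
# "fill-mask" pipeline. A minimal sketch, assuming "bert-base-uncased" purely as an
# illustrative checkpoint (this app deliberately uses a generative model instead):
#
#   fill_mask = pipeline("fill-mask", model="bert-base-uncased")
#   # Each prediction is a dict with "sequence", "score", "token" and "token_str" keys.
#   predictions = fill_mask("The quick brown fox jumps over the [MASK] dog.")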
class InferenceRequest(BaseModel):
    """
    Request model for the /predict endpoint.
    Expects a single string field 'text' containing the sentence with [MASK] tokens.
    """
    text: str
class PredictionResult(BaseModel):
    """
    Response model for individual predictions from the /predict endpoint.
    Simplified to focus on the sequence and score, abstracting away token details.
    """
    sequence: str  # The full sequence with the predicted token filled in
    score: float   # Confidence score of the prediction (approximated for generative LLMs)
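# Illustrative /predict response shape (the values below are examples, not real model output):
#   [
#     {"sequence": "The quick brown fox jumps over the lazy dog.", "score": 0.95},
#     {"sequence": "The quick brown fox jumps over the sleeping dog.", "score": 0.94}
#   ]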
async def run_inference_blocking(generator_pipeline, prompt, num_return_sequences=5):
    """
    Runs the synchronous model inference in a worker thread so it does not block
    FastAPI's event loop.
    """
    return await asyncio.to_thread(
        generator_pipeline,
        prompt,
        max_new_tokens=10,  # Generate only a few tokens for the mask
        num_return_sequences=num_return_sequences,
        do_sample=True,     # Enable sampling for varied predictions
        temperature=0.7,    # Control randomness
        top_k=50,           # Consider only the top K tokens when sampling
        top_p=0.95,         # Nucleus sampling: cumulative probability cutoff
        # Note: the text-generation pipeline does not accept a list of stop strings, so
        # generation is bounded by max_new_tokens and the first word of each completion
        # is extracted in predict_masked_lm.
    )
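# For reference, the text-generation pipeline returns (roughly) a list of dicts such as
#   [{"generated_text": "<prompt><completion>"}, ...]
# with one entry per returned sequence; predict_masked_lm below strips the prompt and
# keeps only the first word of each completion.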
@api_router.post(
    "/predict",  # Prediction endpoint remains /predict
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text using a local TinyLlama model",
    description="Accepts a text string containing a '[MASK]' token and returns up to 5 single-word predictions for the first masked position using a local generative model. Output is simplified to sequences and scores."
)
async def predict_masked_lm(request: InferenceRequest):
    """
    Predicts the most likely tokens for the first [MASK] position in the input text using the TinyLlama model.
    Returns a list of top predictions, each including the full sequence and an approximated score.
    """
    text = request.text
    logger.info(f"Received prediction request for text: '{text}'")
    if "[MASK]" not in text:
        logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Input text must contain at least one '[MASK]' token."
        )
    # Find the position of the first [MASK] token so the LLM can be prompted correctly
    # and predictions can be inserted back into the original text.
    mask_start_index = text.find("[MASK]")
    if mask_start_index == -1:  # Already caught above; kept as a safeguard
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No '[MASK]' token found in input.")
    # Craft a prompt that encourages the generative LLM to behave like a fill-mask model:
    # the "[MASK]" marker is replaced by a blank and a guiding instruction is prepended,
    # e.g. "Fill in the blank: 'The quick brown fox jumps over the ______ dog.' Best options are:"
    # Split the text around the first [MASK]
    parts = text.split("[MASK]", 1)
    if len(parts) < 2:  # Should not happen if [MASK] was found
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error processing mask position.")
    pre_mask_text = parts[0].strip()
    post_mask_text = parts[1].strip()
    # Construct the prompt that guides TinyLlama
    prompt = f"Complete the missing word in the following sentence. Give 5 single-word options. Sentence: '{pre_mask_text} ____ {post_mask_text}' Options:"
    try:
        # Run inference in a worker thread so the main event loop is not blocked.
        # The pipeline output is a list of dicts, e.g. [{"generated_text": "<prompt> <completion>"}].
        raw_predictions = await run_inference_blocking(text_generator, prompt)
        results = []
        seen_words = set()  # Ensure unique predictions
        for i, pred_item in enumerate(raw_predictions):
            generated_text = pred_item.get("generated_text", "")
            # Extract only the predicted word from the generated text. This is heuristic
            # and may need tuning for the actual model output: take the text that comes
            # *after* the prompt and keep its first word.
            if prompt in generated_text:
                completion_text = generated_text.split(prompt, 1)[-1].strip()
                # Take the first whitespace-separated token and strip trailing punctuation
                predicted_word = completion_text.split(' ', 1)[0].strip().replace('.', '').replace(',', '')
                # Filter out numbers, punctuation-only tokens, and very short non-alphabetic words
                if not predicted_word.isalpha() or len(predicted_word) < 2:
                    continue
                predicted_word = predicted_word.lower()  # Normalize to lowercase
                # Skip duplicate predictions
                if predicted_word in seen_words:
                    continue
                seen_words.add(predicted_word)
                # Construct the full sequence with the predicted word filled in
                full_sequence = text.replace("[MASK]", predicted_word, 1)
                # Approximate the score (the generative pipeline does not expose per-word
                # probabilities here); confidence decreases slightly for lower-ranked samples.
                mock_score = 0.95 - (i * 0.01)
                results.append(PredictionResult(
                    sequence=full_sequence,
                    score=mock_score
                ))
                if len(results) >= 5:  # Stop after collecting 5 valid results
                    break
        if not results:
            logger.warning("No valid predictions could be formatted from the LLM response.")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not extract predictions from TinyLlama output.")
        logger.info(f"Successfully processed request via TinyLlama. Returning {len(results)} predictions.")
        return results
    except ValidationError as e:
        logger.error(f"Validation error for request: {e.errors()}")
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors()
        )
    except HTTPException:
        raise  # Re-raise HTTPExceptions raised above (e.g. the 400 and 500 cases)
    except Exception as e:
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}"
        )
@api_router.get(
    "/health",  # Health check endpoint
    summary="Health Check",
    description="Returns a simple message indicating the API is running."
)
async def health_check():
    logger.info("Health check endpoint accessed.")
    return {"message": "Masked Language Model API (via TinyLlama) is running!"}
app.include_router(api_router)
@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def catch_all(request: Request, path_name: str):
    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
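# The app can also be served with the uvicorn CLI (assuming this file is named app.py):
#   uvicorn app:app --host 0.0.0.0 --port 7860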