from fastapi import FastAPI, HTTPException, status, APIRouter, Request
from pydantic import BaseModel, ValidationError
from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
import torch
import logging
import asyncio # For running synchronous model inference in a separate thread
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
app = FastAPI(
    title="Masked Language Model API (via TinyLlama)",
    description="An API to perform Masked Language Modeling using a locally hosted TinyLlama model.",
    version="1.0.0"
)
api_router = APIRouter()
# --- TinyLlama Model Configuration ---
# Using TinyLlama-1.1B-Chat-v1.0 which is a small, Llama-like model suitable for local inference.
MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
# -----------------------------------
# Load model and tokenizer globally to avoid reloading on each request
# This block runs once when the FastAPI application starts.
try:
    logger.info(f"Loading tokenizer and model for {MODEL_NAME}...")
    # Load tokenizer and model for Causal LM (text generation)
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
    # Using torch_dtype=torch.bfloat16 for potential memory/speed benefits if a GPU is available
    # and to fit within common memory limits. Also using device_map="auto" to load efficiently.
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        device_map="auto"
    )
    model.eval()  # Set model to evaluation mode
    # Create a text generation pipeline.
    # We will adjust this pipeline's behavior in predict_masked_lm
    # to simulate masked LM functionality by prompting the LLM.
    text_generator = pipeline(
        "text-generation",
        model=model,
        tokenizer=tokenizer,
        # Ensure pad_token_id is set if the tokenizer does not have one, to avoid warnings/errors
        pad_token_id=tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
    )
    logger.info("Model loaded successfully.")
except Exception as e:
    logger.exception(f"Failed to load model or tokenizer for {MODEL_NAME} during startup!")
    raise RuntimeError(f"Could not load model: {e}")
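# Optional smoke test of the pipeline (illustrative only; the exact completion depends on
# the model and sampling, so the output is not reproducible). Uncomment to verify loading:
# print(text_generator("The capital of France is", max_new_tokens=5)[0]["generated_text"])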
class InferenceRequest(BaseModel):
    """
    Request model for the /predict endpoint.
    Expects a single string field 'text' containing the sentence with [MASK] tokens.
    """
    text: str
class PredictionResult(BaseModel):
    """
    Response model for individual predictions from the /predict endpoint.
    Simplified to focus on the sequence and score, abstracting token details.
    """
    sequence: str  # The full sequence with the predicted token filled in
    score: float   # Confidence score of the prediction (approximated for generative LLMs)
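# Illustrative payload shapes for these models (hypothetical values, not real model output):
#   request:  {"text": "The quick brown fox jumps over the [MASK] dog."}
#   response: [{"sequence": "The quick brown fox jumps over the lazy dog.", "score": 0.95}, ...]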
async def run_inference_blocking(generator_pipeline, prompt, num_return_sequences=5):
    """
    Runs the synchronous model inference in a worker thread (via asyncio.to_thread)
    so it does not block FastAPI's event loop.
    """
    return await asyncio.to_thread(
        generator_pipeline,
        prompt,
        max_new_tokens=10,  # Generate only a small number of tokens for the mask
        num_return_sequences=num_return_sequences,
        do_sample=True,     # Enable sampling for varied predictions
        temperature=0.7,    # Control randomness
        top_k=50,           # Consider top K tokens for sampling
        top_p=0.95,         # Consider tokens up to a certain cumulative probability
        # Note: the text-generation pipeline does not accept a list of stop sequences here,
        # so generation is kept short via max_new_tokens and the completion is trimmed
        # to its first word in predict_masked_lm.
    )
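# For ad-hoc testing outside the FastAPI event loop, the coroutine can be driven with
# asyncio.run (illustrative sketch; the prompt below is hypothetical):
# asyncio.run(run_inference_blocking(text_generator, "Fill in the blank: 'Paris is the ____ of France.'", num_return_sequences=1))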
@api_router.post(
    "/predict",  # Prediction endpoint remains /predict
    response_model=list[PredictionResult],
    summary="Predicts masked tokens in a given text using a local TinyLlama model",
    description="Accepts a text string with '[MASK]' tokens and returns up to 5 single-word predictions for the first masked position using a local generative AI model. Output is simplified to sequences and scores."
)
async def predict_masked_lm(request: InferenceRequest):
    """
    Predicts the most likely tokens for the first [MASK] position in the input text using the TinyLlama model.
    Returns a list of top predictions, each including the full sequence and an approximated score.
    """
    text = request.text
    logger.info(f"Received prediction request for text: '{text}'")
    if "[MASK]" not in text:
        logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Input text must contain at least one '[MASK]' token."
        )
    # Find the position of the first [MASK] token to correctly prompt the LLM
    # and to insert predictions back into the original text.
    mask_start_index = text.find("[MASK]")
    if mask_start_index == -1:  # Should already be caught above, but kept as a safeguard
        raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No '[MASK]' token found in input.")
    # Craft a prompt that encourages the LLM to fill the mask.
    # The prompt guides the generative LLM to act like a fill-mask model.
    # Example: "The quick brown fox jumps over the [MASK] dog. The word that should replace [MASK] is:"
    # We remove "[MASK]" from the prompt for the generative model, then
    # prepend a guiding phrase and append the text after the mask.
    # Split text around the first [MASK]
    parts = text.split("[MASK]", 1)
    if len(parts) < 2:  # Should not happen if [MASK] is present
        raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error processing mask position.")
    pre_mask_text = parts[0].strip()
    post_mask_text = parts[1].strip()
    # Construct the prompt to guide TinyLlama, e.g.:
    # "Fill in the blank: 'The quick brown fox jumps over the ______ dog.' Best options are:"
    prompt = f"Complete the missing word in the following sentence. Give 5 single-word options. Sentence: '{pre_mask_text} ____ {post_mask_text}' Options:"
    try:
        # Run inference in a worker thread so the main event loop is not blocked.
        # The pipeline output is a list of dicts, e.g. [{"generated_text": "<prompt> + predicted word"}].
        raw_predictions = await run_inference_blocking(text_generator, prompt)
        results = []
        seen_words = set()  # To ensure unique predictions
        for i, pred_item in enumerate(raw_predictions):
            generated_text = pred_item.get("generated_text", "")
            # Extract only the predicted word from the generated text.
            # This is heuristic and may need tuning based on actual model output:
            # take the text that comes *after* the prompt and keep its first word.
            if prompt not in generated_text:
                continue
            completion_text = generated_text.split(prompt, 1)[-1].strip()
            # Keep the first whitespace-separated token and strip trailing punctuation
            predicted_word = completion_text.split(' ', 1)[0].strip().replace('.', '').replace(',', '')
            # Filter out numbers, very short strings, and non-alphabetic tokens
            if not predicted_word.isalpha() or len(predicted_word) < 2:
                continue
            predicted_word = predicted_word.lower()  # Normalize to lowercase
            # Ensure unique predictions
            if predicted_word in seen_words:
                continue
            seen_words.add(predicted_word)
            # Construct the full sequence with the predicted word filled in
            full_sequence = text.replace("[MASK]", predicted_word, 1)
            # Approximate score (generative LLMs don't give per-word scores directly);
            # confidence decreases slightly for lower-ranked candidates.
            mock_score = 0.95 - (i * 0.01)
            results.append(PredictionResult(
                sequence=full_sequence,
                score=mock_score
            ))
            if len(results) >= 5:  # Stop after collecting 5 valid results
                break
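        # Illustrative walk-through of the extraction heuristic (hypothetical completion,
        # not real model output):
        #   generated_text  -> "<prompt> lazy, sleeping, brown, small, tired"
        #   completion_text -> "lazy, sleeping, brown, small, tired"
        #   predicted_word  -> "lazy"
        #   full_sequence   -> "The quick brown fox jumps over the lazy dog."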
        if not results:
            logger.warning("No valid predictions could be formatted from LLM response.")
            raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not extract predictions from TinyLlama output.")
        logger.info(f"Successfully processed request via TinyLlama. Returning {len(results)} predictions.")
        return results
    except ValidationError as e:
        # Defensive: FastAPI normally rejects invalid request bodies before the handler runs.
        logger.error(f"Validation error for request: {e.errors()}")
        raise HTTPException(
            status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
            detail=e.errors()
        )
    except HTTPException:
        raise  # Re-raise HTTPExceptions unchanged
    except Exception as e:
        logger.exception(f"An unexpected error occurred during prediction: {e}")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail=f"An internal server error occurred: {e}"
        )
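# Example request for /predict (illustrative; assumes the server was started with the
# __main__ block below, i.e. listening on http://localhost:7860):
#
#   curl -X POST http://localhost:7860/predict \
#        -H "Content-Type: application/json" \
#        -d '{"text": "The quick brown fox jumps over the [MASK] dog."}'
#
# The response is a JSON array of up to 5 objects of the form
# {"sequence": "...", "score": 0.95}; the actual sequences depend on the model's sampling.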
@api_router.get(
    "/health",  # Health check endpoint
    summary="Health Check",
    description="Returns a simple message indicating the API is running."
)
async def health_check():
    logger.info("Health check endpoint accessed.")
    return {"message": "Masked Language Model API (via TinyLlama) is running!"}
app.include_router(api_router)
@app.api_route("/{path_name:path}", methods=["GET", "POST", "PUT", "DELETE", "PATCH", "OPTIONS", "HEAD"])
async def catch_all(request: Request, path_name: str):
    logger.warning(f"Unhandled route accessed: {request.method} {path_name}")
    raise HTTPException(status_code=status.HTTP_404_NOT_FOUND, detail="Not Found")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info") |
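# To run locally (assuming this file is saved as app.py -- the filename is an assumption,
# not fixed by the code above): `python app.py`, or equivalently
# `uvicorn app:app --host 0.0.0.0 --port 7860`.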