brendon-ai committed on
Commit f97c475 · verified · 1 Parent(s): cc98c16

Update app.py

Files changed (1)
  1. app.py +56 -130
app.py CHANGED
@@ -1,51 +1,32 @@
  from fastapi import FastAPI, HTTPException, status, APIRouter, Request
  from pydantic import BaseModel, ValidationError
- from transformers import AutoTokenizer, AutoModelForCausalLM, pipeline
+ from transformers import AutoTokenizer, AutoModelForMaskedLM
  import torch
  import logging
- import asyncio # For running synchronous model inference in a separate thread

  logging.basicConfig(level=logging.INFO)
  logger = logging.getLogger(__name__)

  app = FastAPI(
-     title="Masked Language Model API (via TinyLlama)",
-     description="An API to perform Masked Language Modeling using a locally hosted TinyLlama model.",
+     title="NeuroBERT-Tiny Masked Language Model API",
+     description="An API to perform Masked Language Modeling using the boltuix/NeuroBERT-Tiny model.",
      version="1.0.0"
  )

  api_router = APIRouter()

- # --- TinyLlama Model Configuration ---
- # Using TinyLlama-1.1B-Chat-v1.0 which is a small, Llama-like model suitable for local inference.
- MODEL_NAME = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
- # -----------------------------------
+ # --- NeuroBERT-Tiny Model Configuration ---
+ # Using boltuix/NeuroBERT-Tiny for Masked Language Modeling.
+ MODEL_NAME = "boltuix/NeuroBERT-Tiny"
+ # ----------------------------------------

- # Load model and tokenizer globally to avoid reloading on each request
+ # Load model globally to avoid reloading on each request
  # This block runs once when the FastAPI application starts.
  try:
      logger.info(f"Loading tokenizer and model for {MODEL_NAME}...")
-     # Load tokenizer and model for Causal LM (text generation)
      tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
-     # Using torch_dtype=torch.bfloat16 for potential memory/speed benefits if GPU is available
-     # and to fit within common memory limits. Also using device_map="auto" to load efficiently.
-     model = AutoModelForCausalLM.from_pretrained(
-         MODEL_NAME,
-         torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
-         device_map="auto"
-     )
+     model = AutoModelForMaskedLM.from_pretrained(MODEL_NAME)
      model.eval() # Set model to evaluation mode
-
-     # Create a text generation pipeline
-     # We will adjust this pipeline's behavior in predict_masked_lm
-     # to simulate masked LM functionality by prompting the LLM.
-     text_generator = pipeline(
-         "text-generation",
-         model=model,
-         tokenizer=tokenizer,
-         # Ensure pad_token_id is set if tokenizer does not have one, to avoid warnings/errors
-         pad_token_id=tokenizer.eos_token_id if tokenizer.pad_token_id is None else tokenizer.pad_token_id
-     )
      logger.info("Model loaded successfully.")
  except Exception as e:
      logger.exception(f"Failed to load model or tokenizer for {MODEL_NAME} during startup!")
@@ -61,124 +42,69 @@ class InferenceRequest(BaseModel):
  class PredictionResult(BaseModel):
      """
      Response model for individual predictions from the /predict endpoint.
-     Simplified to focus on the sequence and score, abstracting token details.
      """
      sequence: str # The full sequence with the predicted token filled in
-     score: float # Confidence score of the prediction (approximated for generative LLMs)
-
- async def run_inference_blocking(generator_pipeline, prompt, num_return_sequences=5):
-     """
-     Runs the synchronous model inference in a separate thread to avoid blocking FastAPI's event loop.
-     """
-     return generator_pipeline(
-         prompt,
-         max_new_tokens=10, # Generate a small number of tokens for the mask
-         num_return_sequences=num_return_sequences,
-         do_sample=True, # Enable sampling for varied predictions
-         temperature=0.7, # Control randomness
-         top_k=50, # Consider top K tokens for sampling
-         top_p=0.95, # Consider tokens up to a certain cumulative probability
-         # The stop_sequence ensures it doesn't generate too much beyond the expected word
-         stop_sequence=[" ", ".", ",", "!", "?", "\n"] # Stop after generating a word/punctuation
-     )
-
+     score: float # Confidence score of the prediction
+     token: int # The ID of the predicted token
+     token_str: str # The string representation of the predicted token

  @api_router.post(
-     "/predict", # Prediction endpoint remains /predict
+     "/predict", # Prediction endpoint
      response_model=list[PredictionResult],
-     summary="Predicts masked tokens in a given text using a local TinyLlama model",
-     description="Accepts a text string with '[MASK]' tokens and returns up to 5 single-word predictions for each masked position using a local generative AI model. Output is simplified to sequences and scores."
+     summary="Predicts masked tokens in a given text using NeuroBERT-Tiny",
+     description="Accepts a text string with '[MASK]' tokens and returns top 5 predictions for each masked position."
  )
  async def predict_masked_lm(request: InferenceRequest):
      """
-     Predicts the most likely tokens for [MASK] positions in the input text using the TinyLlama model.
-     Returns a list of top predictions for each masked token, including the full sequence and an approximated score.
+     Predicts the most likely tokens for [MASK] positions in the input text using the NeuroBERT-Tiny model.
+     Returns a list of top 5 predictions for each masked token, including the full sequence, score, and token details.
      """
-     text = request.text
-     logger.info(f"Received prediction request for text: '{text}'")
+     try:
+         text = request.text
+         logger.info(f"Received prediction request for text: '{text}'")

-     if "[MASK]" not in text:
-         logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
-         raise HTTPException(
-             status_code=status.HTTP_400_BAD_REQUEST,
-             detail="Input text must contain at least one '[MASK]' token."
-         )
-
-     # Find the position of the first [MASK] token to correctly prompt the LLM
-     # And to insert predictions back into the original text
-     mask_start_index = text.find("[MASK]")
-     if mask_start_index == -1: # Should already be caught above, but as a safeguard
-         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="No '[MASK]' token found in input.")
-
-     # Craft a prompt that encourages the LLM to fill the mask.
-     # The prompt guides the generative LLM to act like a fill-mask model.
-     # Example: "The quick brown fox jumps over the [MASK] dog. The word that should replace [MASK] is:"
-     # We remove "[MASK]" from the prompt for the generative model, and then
-     # prepend a guiding phrase and append the text after the mask.
-
-     # Split text around the first [MASK]
-     parts = text.split("[MASK]", 1)
-     if len(parts) < 2: # Should not happen if [MASK] is found
-         raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Error processing mask position.")
-
-     pre_mask_text = parts[0].strip()
-     post_mask_text = parts[1].strip()
+         inputs = tokenizer(text, return_tensors="pt")

-     # Construct the prompt to guide TinyLlama
-     # "Fill in the blank: 'The quick brown fox jumps over the ______ dog.' Best options are:"
-     prompt = f"Complete the missing word in the following sentence. Give 5 single-word options. Sentence: '{pre_mask_text} ____ {post_mask_text}' Options:"
+         with torch.no_grad():
+             outputs = model(**inputs)
+
+         logits = outputs.logits
+         masked_token_id = tokenizer.convert_tokens_to_ids("[MASK]")
+
+         # Find all masked tokens
+         masked_token_indices = torch.where(inputs["input_ids"] == masked_token_id)[1]
+
+         if not masked_token_indices.numel():
+             logger.warning("No [MASK] token found in the input text. Returning 400 Bad Request.")
+             raise HTTPException(
+                 status_code=status.HTTP_400_BAD_REQUEST,
+                 detail="Input text must contain at least one '[MASK]' token."
+             )

-     try:
-         # Run inference in a separate thread to not block the main event loop
-         # The model's output will be a list of dicts, e.g., [{"generated_text": "Prompt + predicted word"}]
-         raw_predictions = await run_inference_blocking(text_generator, prompt)
-
          results = []
-         seen_words = set() # To ensure unique predictions
-
-         for i, pred_item in enumerate(raw_predictions):
-             generated_text = pred_item.get("generated_text", "")
-
-             # Extract only the predicted word from the generated text
-             # This is heuristic and might need fine-tuning based on actual model output
-             # We look for text that comes *after* our prompt and try to extract the first word.
-             if prompt in generated_text:
-                 completion_text = generated_text.split(prompt, 1)[-1].strip()
-                 # Try to extract the first word if it contains spaces
-                 predicted_word = completion_text.split(' ', 1)[0].strip().replace('.', '').replace(',', '')
-                 # Filter out numbers, common filler words, or very short non-alpha words
-                 if not predicted_word.isalpha() or len(predicted_word) < 2:
-                     continue
-
-                 # Further refine by splitting on common word separators, taking the first valid word
-                 valid_words = [w for w in predicted_word.split() if w.isalpha() and len(w) > 1]
-                 if not valid_words: continue
-                 predicted_word = valid_words[0].lower() # Normalize to lowercase
-
-                 # Ensure unique predictions
-                 if predicted_word in seen_words:
-                     continue
-                 seen_words.add(predicted_word)
-
-                 # Construct the full sequence with the predicted word
-                 full_sequence = text.replace("[MASK]", predicted_word, 1)
-
-                 # Approximate score (generative LLMs don't give scores directly for words)
-                 mock_score = 0.95 - (i * 0.01) # Slightly decrease confidence for lower ranks
+         for masked_index in masked_token_indices:
+             # Get top 5 predictions for the masked token
+             top_5_logits = torch.topk(logits[0, masked_index], 5).values
+             top_5_tokens = torch.topk(logits[0, masked_index], 5).indices
+
+             for i in range(5):
+                 score = torch.nn.functional.softmax(logits[0, masked_index], dim=-1)[top_5_tokens[i]].item()
+                 predicted_token_id = top_5_tokens[i].item()
+                 predicted_token_str = tokenizer.decode(predicted_token_id)
+
+                 # Replace the [MASK] with the predicted token for the full sequence
+                 temp_input_ids = inputs["input_ids"].clone()
+                 temp_input_ids[0, masked_index] = predicted_token_id
+                 full_sequence = tokenizer.decode(temp_input_ids[0], skip_special_tokens=True)

                  results.append(PredictionResult(
                      sequence=full_sequence,
-                     score=mock_score
+                     score=score,
+                     token=predicted_token_id,
+                     token_str=predicted_token_str
                  ))
-
-                 if len(results) >= 5: # Stop after getting 5 valid results
-                     break

-         if not results:
-             logger.warning("No valid predictions could be formatted from LLM response.")
-             raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Could not extract predictions from TinyLlama output.")
-
-         logger.info(f"Successfully processed request via TinyLlama. Returning {len(results)} predictions.")
+         logger.info(f"Successfully processed request. Returning {len(results)} predictions.")
          return results

      except ValidationError as e:
@@ -203,7 +129,7 @@ async def predict_masked_lm(request: InferenceRequest):
  )
  async def health_check():
      logger.info("Health check endpoint accessed.")
-     return {"message": "Masked Language Model API (via TinyLlama) is running!"}
+     return {"message": "NeuroBERT-Tiny API is running!"}

  app.include_router(api_router)

@@ -214,4 +140,4 @@ async def catch_all(request: Request, path_name: str):

  if __name__ == "__main__":
      import uvicorn
-     uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
+     uvicorn.run(app, host="0.0.0.0", port=7860, log_level="info")
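For reference, a minimal client-side sketch of how the updated /predict endpoint could be exercised once the app is running. The base URL and example sentence are assumptions (the port matches the uvicorn.run call above); the request and response fields follow the InferenceRequest and PredictionResult models in the diff.

import requests

# Assumed local deployment; port 7860 matches the __main__ block in app.py.
API_URL = "http://localhost:7860/predict"

# InferenceRequest expects a "text" field containing at least one [MASK] token.
payload = {"text": "The capital of France is [MASK]."}

response = requests.post(API_URL, json=payload, timeout=30)
response.raise_for_status()

# Each item mirrors PredictionResult: sequence, score, token, token_str.
for prediction in response.json():
    print(f"{prediction['score']:.4f} {prediction['token_str']!r} -> {prediction['sequence']}")

A similar top-5 result could also be produced server-side with transformers' fill-mask pipeline, which returns the same sequence/score/token/token_str fields; the hand-rolled top-k loop in this commit achieves the equivalent with direct model calls.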