Abaryan committed (verified)
Commit 30ca71a · 1 Parent(s): e8f34c8

Update app.py

Files changed (1)
  1. app.py +138 -103
app.py CHANGED
@@ -2,14 +2,52 @@ from fastapi import FastAPI, HTTPException
 from fastapi.middleware.cors import CORSMiddleware
 from pydantic import BaseModel
 import torch
-from transformers import AutoModelForMultipleChoice, AutoTokenizer
+from transformers import AutoModelForCausalLM, AutoTokenizer
 import os
 from datasets import load_dataset
 import random
-from typing import Optional, List
+from typing import Optional, List, Tuple, Union
 import gradio as gr
+from contextlib import asynccontextmanager
 
-app = FastAPI()
+# Global variables
+model = None
+tokenizer = None
+dataset = None
+
+@asynccontextmanager
+async def lifespan(app: FastAPI):
+    # Startup: Load the model
+    global model, tokenizer, dataset
+    try:
+        # Load your fine-tuned model and tokenizer
+        model_name = os.getenv("MODEL_NAME", "rgb2gbr/BioXP-0.5B-MedMCQA")
+        model = AutoModelForCausalLM.from_pretrained(model_name)
+        tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+        # Load MedMCQA dataset
+        dataset = load_dataset("openlifescienceai/medmcqa")
+
+        # Move model to GPU if available
+        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+        model = model.to(device)
+        model.eval()
+    except Exception as e:
+        print(f"Error loading model: {str(e)}")
+        raise e
+
+    yield  # This is where FastAPI serves the application
+
+    # Shutdown: Clean up resources if needed
+    if model is not None:
+        del model
+    if tokenizer is not None:
+        del tokenizer
+    if dataset is not None:
+        del dataset
+    torch.cuda.empty_cache()
+
+app = FastAPI(lifespan=lifespan)
 
 # Add CORS middleware for Gradio
 app.add_middleware(
@@ -34,82 +72,97 @@ class DatasetQuestion(BaseModel):
     cop: Optional[int] = None  # Correct option (0-3)
     exp: Optional[str] = None  # Explanation if available
 
-# Global variables
-model = None
-tokenizer = None
-dataset = None
+def format_prompt(question: str, options: List[str]) -> str:
+    """Format the prompt for the model"""
+    prompt = f"Question: {question}\n\nOptions:\n"
+    for i, opt in enumerate(options):
+        prompt += f"{chr(65+i)}. {opt}\n"
+    prompt += "\nAnswer:"
+    return prompt
 
-def load_model():
-    global model, tokenizer, dataset
-    try:
-        # Load your fine-tuned model and tokenizer
-        model_name = os.getenv("BioXP-0.5b", "rgb2gbr/GRPO_BioMedmcqa_Qwen2.5-0.5B")
-        model = AutoModelForMultipleChoice.from_pretrained(model_name)
-        tokenizer = AutoTokenizer.from_pretrained(model_name)
-
-        # Load MedMCQA dataset
-        dataset = load_dataset("openlifescienceai/medmcqa")
-
-        # Move model to GPU if available
-        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
-        model = model.to(device)
-        model.eval()
-    except Exception as e:
-        raise Exception(f"Error loading model: {str(e)}")
+def get_question(index: Optional[int] = None, random_question: bool = False, format: str = "api") -> Union[DatasetQuestion, Tuple[str, str, str, str, str]]:
+    """
+    Get a question from the dataset.
+    Args:
+        index: Optional question index
+        random_question: Whether to get a random question
+        format: 'api' for DatasetQuestion object, 'gradio' for tuple
+    """
+    if dataset is None:
+        raise Exception("Dataset not loaded")
+
+    if random_question:
+        index = random.randint(0, len(dataset['train']) - 1)
+    elif index is None:
+        raise ValueError("Either index or random_question must be provided")
+
+    question_data = dataset['train'][index]
+
+    if format == "gradio":
+        return (
+            question_data['question'],
+            question_data['opa'],
+            question_data['opb'],
+            question_data['opc'],
+            question_data['opd']
+        )
+
+    return DatasetQuestion(
+        question=question_data['question'],
+        opa=question_data['opa'],
+        opb=question_data['opb'],
+        opc=question_data['opc'],
+        opd=question_data['opd'],
+        cop=question_data['cop'] if 'cop' in question_data else None,
+        exp=question_data['exp'] if 'exp' in question_data else None
+    )
 
 def predict_gradio(question: str, option_a: str, option_b: str, option_c: str, option_d: str):
     """Gradio interface prediction function"""
     try:
         options = [option_a, option_b, option_c, option_d]
-        inputs = []
-        for option in options:
-            text = f"{question} {option}"
-            inputs.append(text)
 
-        encodings = tokenizer(
-            inputs,
+        # Format the prompt
+        prompt = format_prompt(question, options)
+
+        # Tokenize the input
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
             padding=True,
             truncation=True,
-            max_length=512,
-            return_tensors="pt"
+            max_length=512
         )
 
         device = next(model.parameters()).device
-        encodings = {k: v.to(device) for k, v in encodings.items()}
+        inputs = {k: v.to(device) for k, v in inputs.items()}
 
+        # Generate prediction
         with torch.no_grad():
-            outputs = model(**encodings)
-            logits = outputs.logits
-            probabilities = torch.softmax(logits, dim=1)[0].tolist()
-            predicted_class = torch.argmax(logits, dim=1).item()
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=10,
+                num_return_sequences=1,
+                temperature=0.7,
+                do_sample=False,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        # Decode the output
+        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Extract the answer from the prediction
+        answer = prediction.split("Answer:")[-1].strip()
 
         # Format the output for Gradio
-        result = f"Predicted Answer: {options[predicted_class]}\n\n"
-        result += "Confidence Scores:\n"
-        for i, (opt, prob) in enumerate(zip(options, probabilities)):
-            result += f"{opt}: {prob:.2%}\n"
+        result = f"Model Output:\n{prediction}\n\n"
+        result += f"Extracted Answer: {answer}"
 
         return result
 
     except Exception as e:
         return f"Error: {str(e)}"
 
-def get_random_question():
-    """Get a random question for Gradio interface"""
-    if dataset is None:
-        return "Error: Dataset not loaded", "", "", "", ""
-
-    index = random.randint(0, len(dataset['train']) - 1)
-    question_data = dataset['train'][index]
-
-    return (
-        question_data['question'],
-        question_data['opa'],
-        question_data['opb'],
-        question_data['opc'],
-        question_data['opd']
-    )
-
 # Create Gradio interface
 with gr.Blocks(title="Medical MCQ Predictor") as demo:
     gr.Markdown("# Medical MCQ Predictor")
@@ -136,7 +189,7 @@ with gr.Blocks(title="Medical MCQ Predictor") as demo:
     )
 
     random_btn.click(
-        fn=get_random_question,
+        fn=lambda: get_question(random_question=True, format="gradio"),
         inputs=[],
         outputs=[question, option_a, option_b, option_c, option_d]
     )
@@ -144,36 +197,11 @@ with gr.Blocks(title="Medical MCQ Predictor") as demo:
 # Mount Gradio app to FastAPI
 app = gr.mount_gradio_app(app, demo, path="/")
 
-@app.on_event("startup")
-async def startup_event():
-    load_model()
-
 @app.get("/dataset/question")
 async def get_dataset_question(index: Optional[int] = None, random_question: bool = False):
     """Get a question from the MedMCQA dataset"""
     try:
-        if dataset is None:
-            raise HTTPException(status_code=500, detail="Dataset not loaded")
-
-        if random_question:
-            index = random.randint(0, len(dataset['train']) - 1)
-        elif index is None:
-            raise HTTPException(status_code=400, detail="Either index or random_question must be provided")
-
-        question_data = dataset['train'][index]
-
-        question = DatasetQuestion(
-            question=question_data['question'],
-            opa=question_data['opa'],
-            opb=question_data['opb'],
-            opc=question_data['opc'],
-            opd=question_data['opd'],
-            cop=question_data['cop'] if 'cop' in question_data else None,
-            exp=question_data['exp'] if 'exp' in question_data else None
-        )
-
-        return question
-
+        return get_question(index=index, random_question=random_question)
     except Exception as e:
         raise HTTPException(status_code=500, detail=str(e))
 
@@ -183,35 +211,42 @@ async def predict(request: QuestionRequest):
         raise HTTPException(status_code=400, detail="Exactly 4 options are required")
 
     try:
-        inputs = []
-        for option in request.options:
-            text = f"{request.question} {option}"
-            inputs.append(text)
+        # Format the prompt
+        prompt = format_prompt(request.question, request.options)
 
-        encodings = tokenizer(
-            inputs,
+        # Tokenize the input
+        inputs = tokenizer(
+            prompt,
+            return_tensors="pt",
            padding=True,
            truncation=True,
-            max_length=512,
-            return_tensors="pt"
+            max_length=512
         )
 
         device = next(model.parameters()).device
-        encodings = {k: v.to(device) for k, v in encodings.items()}
+        inputs = {k: v.to(device) for k, v in inputs.items()}
 
+        # Generate prediction
        with torch.no_grad():
-            outputs = model(**encodings)
-            logits = outputs.logits
-            probabilities = torch.softmax(logits, dim=1)[0].tolist()
-            predicted_class = torch.argmax(logits, dim=1).item()
+            outputs = model.generate(
+                **inputs,
+                max_new_tokens=10,
+                num_return_sequences=1,
+                temperature=0.7,
+                do_sample=False,
+                pad_token_id=tokenizer.eos_token_id
+            )
+
+        # Decode the output
+        prediction = tokenizer.decode(outputs[0], skip_special_tokens=True)
+
+        # Extract the answer from the prediction
+        answer = prediction.split("Answer:")[-1].strip()
 
         response = {
-            "predicted_option": request.options[predicted_class],
-            "option_index": predicted_class,
-            "confidence": probabilities[predicted_class],
-            "probabilities": {
-                f"option_{i}": prob for i, prob in enumerate(probabilities)
-            }
+            "model_output": prediction,
+            "extracted_answer": answer,
+            "full_response": prediction
         }
 
         return response
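
For reviewers who want to exercise the updated endpoint, here is a minimal client sketch. It assumes the Space is running locally at http://localhost:7860 and that the handler above is mounted at /predict with the QuestionRequest body (a question plus exactly four options); the URL, port, and sample question are illustrative assumptions, while the response keys (model_output, extracted_answer, full_response) come from this commit.

# Minimal client sketch for the updated prediction endpoint.
# Assumptions: the app is reachable at http://localhost:7860 and the handler
# shown in this diff is mounted at /predict; adjust both to match the deployment.
import requests

BASE_URL = "http://localhost:7860"  # assumed local URL for the Space

payload = {
    "question": "Which vitamin deficiency causes scurvy?",  # illustrative example
    "options": ["Vitamin A", "Vitamin B12", "Vitamin C", "Vitamin D"],  # exactly 4 options required
}

resp = requests.post(f"{BASE_URL}/predict", json=payload, timeout=60)
resp.raise_for_status()
data = resp.json()

# After this commit the endpoint returns generated text rather than per-option
# probabilities: "model_output", "extracted_answer", and "full_response".
print(data["extracted_answer"])

Note that the new response no longer includes an option index or confidence scores, so clients that previously consumed predicted_option or probabilities will need to be updated.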