hponepyae committed on
Commit eef3b89 · verified · 1 Parent(s): 77793f4

Update app.py

Files changed (1): app.py +20 -109
app.py CHANGED
@@ -8,18 +8,19 @@ import spaces
 # --- Configuration & Model Loading ---
 # Use the pipeline, which is more robust as seen in the working example
 print("Loading MedGemma model via pipeline...")
+model_loaded = False
+pipe = None
 try:
     pipe = pipeline(
-        "image-to-text", # The correct task for this model
+        "image-to-text",
         model="google/medgemma-4b-it",
-        model_kwargs={"torch_dtype": torch.bfloat16}, # Pass dtype here
+        model_kwargs={"torch_dtype": torch.bfloat16},
         device_map="auto",
         token=os.environ.get("HF_TOKEN")
     )
     model_loaded = True
     print("Model loaded successfully!")
 except Exception as e:
-    model_loaded = False
     print(f"Error loading model: {e}")
 
 
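Note on this hunk: `model_loaded` and `pipe` are now initialized before the `try:`, so a failed load leaves them defined instead of raising `NameError` later. A minimal sketch of how downstream code can build on that guard (the `ensure_model_ready` helper is hypothetical, not part of this commit, and assumes the module-level `model_loaded` and `pipe` above):

```python
# Hypothetical guard using the flags this hunk pre-initializes.
# If pipeline() raised (missing HF_TOKEN, gated repo, OOM), pipe is still None
# and the chat handler can fail soft instead of crashing at request time.
def ensure_model_ready(history, user_input):
    if not model_loaded or pipe is None:
        history.append((user_input, "Model unavailable. Check HF_TOKEN and the startup logs."))
        return False
    return True
```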
@@ -36,45 +37,31 @@ def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
 
     current_image = new_image_upload if new_image_upload is not None else image_state
 
-    # --- THE CORRECT IMPLEMENTATION ---
-    # Build the 'messages' list using the exact format from the working X-ray app.
+    # Build the 'messages' list using the correct format for the pipeline
     messages = []
 
-    # Optional: System prompt can be added here if needed, following the same format.
-
     # Process the conversation history
     for user_msg, assistant_msg in history_for_display:
-        # For history turns, we assume the image was part of the first turn (handled below).
-        # So, all historical messages are just text.
-        messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
+        messages.append({"role": "user", "content": user_msg})
         if assistant_msg:
-            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
+            messages.append({"role": "assistant", "content": assistant_msg})
 
     # Add the current user turn
-    current_user_content = [{"type": "text", "text": user_input}]
-    # If there's an image for the conversation, add it to the first user turn's content
-    if current_image is not None and not history_for_display: # Only for the very first message
-        current_user_content.append({"type": "image"}) # The pipeline handles the image object separately
-
-    messages.append({"role": "user", "content": current_user_content})
+    messages.append({"role": "user", "content": user_input})
 
     try:
-        # Generate analysis using the pipeline. It's much simpler.
-        # We pass the image separately if it exists.
+        # The pipeline call is simpler. We pass the image as the main argument
+        # and the text conversation as the `prompt`.
         if current_image:
-            output = pipe(current_image, prompt=messages, generate_kwargs={"max_new_tokens": 512})
+            # The image goes first, the prompt kwarg contains the conversation history
+            output = pipe(current_image, prompt=messages, generate_kwargs={"max_new_tokens": 512, "do_sample": True, "temperature": 0.7})
         else:
             # If no image, the pipeline can work with just the prompt
-            output = pipe(prompt=messages, generate_kwargs={"max_new_tokens": 512})
+            output = pipe(prompt=messages, generate_kwargs={"max_new_tokens": 512, "do_sample": True, "temperature": 0.7})
 
-        # The pipeline's output structure can be complex; we need to extract the final text.
-        # It's usually in the last dictionary of the generated list.
-        result = output[0]["generated_text"]
-        if isinstance(result, list):
-            # Find the last text content from the model's response
-            clean_response = next((item['text'] for item in reversed(result) if item['type'] == 'text'), "Sorry, I couldn't generate a response.")
-        else: # Simpler text-only output
-            clean_response = result
+        # The pipeline output structure contains the full conversation.
+        # We want the content of the last message, which is the model's reply.
+        clean_response = output[0]["generated_text"][-1]['content']
 
     except Exception as e:
         print(f"Caught a critical exception during generation: {e}", flush=True)
@@ -86,73 +73,6 @@ def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
     history_for_display.append((user_input, clean_response))
     return history_for_display, current_image, None, ""
 
-# --- Gradio Interface (Mostly unchanged) ---
-with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
-    gr.Markdown(
-        """
-        # AI Symptom Checker powered by MedGemma
-        Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image. The AI will analyze the inputs and ask clarifying questions if needed.
-        """
-    )
-
-    image_state = gr.State(None)
-    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
-    chat_history = gr.State([])
-
-    with gr.Row():
-        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
-
-    with gr.Row():
-        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
-        submit_btn = gr.Button("Send", variant="primary", scale=1)
-
-    def clear_all():
-        return [], None, None, ""
-
-    clear_btn = gr.Button("Start New Conversation")
-    clear_btn.click(fn=clear_all, outputs=[chat_history, image_state, image_box, text_box], queue=False)
-
-    def on_submit(user_input, display_history, new_image, persisted_image):
-        if not user_input.strip() and not new_image:
-            return display_history, persisted_image, None, ""
-        # The display history IS our history state now
-        return symptom_checker_chat(user_input, display_history, new_image, persisted_image)
-
-    submit_btn.click(
-        fn=on_submit,
-        inputs=[text_box, chat_history, image_box, image_state],
-        outputs=[chat_history, image_state, image_box, text_box]
-    )
-    text_box.submit(
-        fn=on_submit,
-        inputs=[text_box, chat_history, image_box, image_state],
-        outputs=[chat_history, image_state, image_box, text_box]
-    )
-
-if __name__ == "__main__":
-    demo.launch(debug=True)
-        # Generate the response
-        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
-
-        # Decode only the newly generated part
-        input_token_len = inputs["input_ids"].shape[1]
-        generated_tokens = outputs[:, input_token_len:]
-        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
-
-    except Exception as e:
-        print(f"Caught a critical exception during generation: {e}", flush=True)
-        # Display the real error in the UI for easier debugging
-        clean_response = (
-            "An error occurred during generation. This is the technical details:\n\n"
-            f"```\n{type(e).__name__}: {e}\n```"
-        )
-
-    # Update the display history
-    history_for_display.append((user_input, clean_response))
-
-    # Return all updated values
-    return history_for_display, current_image, None, ""
-
 # --- Gradio Interface ---
 with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
     gr.Markdown(
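The 67 lines removed above were dead code: a duplicated Gradio interface plus an orphaned generation fragment referencing `model`, `inputs`, and `processor`, none of which exist in the pipeline-based version. For reference, a sketch of the raw (non-pipeline) API that fragment came from, assuming a recent `transformers` release with `AutoModelForImageTextToText` and processor chat templates; this is not code from this commit:

```python
import os
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText

model_id = "google/medgemma-4b-it"
processor = AutoProcessor.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
model = AutoModelForImageTextToText.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    token=os.environ.get("HF_TOKEN"),
)

messages = [{"role": "user", "content": [{"type": "text", "text": "I have a rash on my arm."}]}]
inputs = processor.apply_chat_template(
    messages, add_generation_prompt=True, tokenize=True,
    return_dict=True, return_tensors="pt",
).to(model.device)

outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
# Decode only the newly generated tokens, as the removed fragment attempted.
input_token_len = inputs["input_ids"].shape[1]
clean_response = processor.decode(outputs[0, input_token_len:], skip_special_tokens=True).strip()
print(clean_response)
```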
@@ -162,44 +82,34 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
         """
     )
 
-    # State to hold the image across an entire conversation
     image_state = gr.State(None)
-
     chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False, avatar_images=("user.png", "bot.png"))
-    # The history state will now just be for display, a simple list of (text, text) tuples.
     chat_history = gr.State([])
 
     with gr.Row():
         image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
 
     with gr.Row():
-        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
+        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm that is red and itchy...", scale=4)
         submit_btn = gr.Button("Send", variant="primary", scale=1)
 
-    # The clear function now resets all three states
     def clear_all():
         return [], None, None, ""
 
     clear_btn = gr.Button("Start New Conversation")
-    clear_btn.click(
-        fn=clear_all,
-        outputs=[chat_history, image_state, image_box, text_box],
-        queue=False
-    )
+    clear_btn.click(fn=clear_all, outputs=[chat_history, image_state, image_box, text_box], queue=False)
 
-    # The submit handler function
     def on_submit(user_input, display_history, new_image, persisted_image):
         if not user_input.strip() and not new_image:
             return display_history, persisted_image, None, ""
         return symptom_checker_chat(user_input, display_history, new_image, persisted_image)
 
-    # Wire up the events
+    # Event Handlers for submit button and enter key
     submit_btn.click(
         fn=on_submit,
         inputs=[text_box, chat_history, image_box, image_state],
         outputs=[chat_history, image_state, image_box, text_box]
     )
-
     text_box.submit(
         fn=on_submit,
         inputs=[text_box, chat_history, image_box, image_state],
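One caveat the committed wiring leaves open: both handlers write to `chat_history` (a `gr.State`) but never to the `chatbot` component, so the on-screen conversation does not refresh. A possible fix, not part of this commit, adds `chatbot` to the outputs and returns the history twice, once for the state and once for display:

```python
def on_submit(user_input, display_history, new_image, persisted_image):
    if not user_input.strip() and not new_image:
        return display_history, display_history, persisted_image, None, ""
    history, image, cleared_image_box, cleared_text = symptom_checker_chat(
        user_input, display_history, new_image, persisted_image
    )
    # Duplicate the history: the first copy feeds the gr.State, the second the Chatbot.
    return history, history, image, cleared_image_box, cleared_text

submit_btn.click(
    fn=on_submit,
    inputs=[text_box, chat_history, image_box, image_state],
    outputs=[chat_history, chatbot, image_state, image_box, text_box],
)
```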
@@ -208,3 +118,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
 
 if __name__ == "__main__":
     demo.launch(debug=True)
+