hponepyae commited on
Commit
46668b2
·
verified ·
1 Parent(s): 6ef5bdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -36
app.py CHANGED
@@ -34,34 +34,24 @@ except Exception as e:
34
  def symptom_checker_chat(user_input, history_state, new_image_upload, image_state):
35
  """
36
  Manages the conversational flow, persisting the image across turns.
37
-
38
- Args:
39
- user_input (str): Text from the user.
40
- history_state (list): Stateful conversation history.
41
- new_image_upload (PIL.Image): A new image uploaded in this turn.
42
- image_state (PIL.Image): The image persisted from a previous turn.
43
-
44
- Returns:
45
- tuple: Updated values for all relevant Gradio components.
46
  """
47
  if not model_loaded:
48
  history_state.append((user_input, "Error: The model could not be loaded."))
49
  return history_state, history_state, None, None, ""
50
 
51
- # FIX: Determine which image to use. A new upload takes precedence.
52
- # This is the key to solving the "image amnesia" problem.
53
  current_image = new_image_upload if new_image_upload is not None else image_state
54
 
55
- # If this is the *first* turn with an image, add the <image> token.
56
- # Don't add it again if the image is just being carried over in the state.
57
  if new_image_upload is not None:
58
  model_input_text = f"<image>\n{user_input}"
59
  else:
60
  model_input_text = user_input
61
 
62
- system_prompt = "You are an expert, empathetic AI medical assistant..." # Keep your detailed prompt
 
 
 
63
 
64
- # Construct the full conversation history
65
  conversation = [{"role": "system", "content": system_prompt}]
66
  for turn_input, assistant_output in history_state:
67
  conversation.append({"role": "user", "content": turn_input})
@@ -73,41 +63,44 @@ def symptom_checker_chat(user_input, history_state, new_image_upload, image_stat
73
  conversation, tokenize=False, add_generation_prompt=True
74
  )
75
 
76
- # Process inputs. Crucially, pass `current_image` to the processor.
77
  try:
 
78
  if current_image:
79
- inputs = processor(text=prompt, images=current_image, return_tensors="pt").to(model.device, dtype)
 
80
  else:
81
  inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
82
 
 
83
  outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
84
  input_token_len = inputs["input_ids"].shape[1]
85
  generated_tokens = outputs[:, input_token_len:]
86
  clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
87
 
88
  except Exception as e:
89
- print(f"Error during model generation: {e}")
90
- clean_response = "An error occurred during generation. Please check the logs."
91
-
92
- # Update the text history with the model-aware input
 
 
 
 
 
93
  history_state.append((model_input_text, clean_response))
94
-
95
- # Create the display history without the special tokens
96
  display_history = [(turn.replace("<image>\n", ""), resp) for turn, resp in history_state]
97
 
98
- # Return everything:
99
- # 1. display_history -> to the chatbot UI
100
- # 2. history_state -> to the text history state
101
- # 3. current_image -> to the image state to be persisted
102
- # 4. None -> to the image upload box to clear it
103
- # 5. "" -> to the text box to clear it
104
  return display_history, history_state, current_image, None, ""
105
 
106
- # --- Gradio Interface ---
107
  with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
108
- gr.Markdown("# AI Symptom Checker powered by MedGemma\n...") # Keep your intro markdown
 
 
 
 
 
109
 
110
- # FIX: Add a new state component to hold the image across turns
111
  image_state = gr.State(None)
112
 
113
  chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
@@ -117,10 +110,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
117
  image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
118
 
119
  with gr.Row():
120
- text_box = gr.Textbox(label="Describe your symptoms...", scale=4)
121
  submit_btn = gr.Button("Send", variant="primary", scale=1)
122
 
123
- # FIX: Update the clear function to also clear the new image_state
124
  def clear_all():
125
  return [], [], None, None, ""
126
 
@@ -131,11 +123,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
131
  queue=False
132
  )
133
 
134
- # Combine submit actions into a single function for DRY principle
135
  def on_submit(user_input, history, new_image, persisted_image):
136
  return symptom_checker_chat(user_input, history, new_image, persisted_image)
137
 
138
- # FIX: Update inputs and outputs to include the new image_state
139
  submit_btn.click(
140
  fn=on_submit,
141
  inputs=[text_box, chat_history, image_box, image_state],
 
34
  def symptom_checker_chat(user_input, history_state, new_image_upload, image_state):
35
  """
36
  Manages the conversational flow, persisting the image across turns.
37
+ Includes robust error reporting directly in the UI for debugging.
 
 
 
 
 
 
 
 
38
  """
39
  if not model_loaded:
40
  history_state.append((user_input, "Error: The model could not be loaded."))
41
  return history_state, history_state, None, None, ""
42
 
 
 
43
  current_image = new_image_upload if new_image_upload is not None else image_state
44
 
 
 
45
  if new_image_upload is not None:
46
  model_input_text = f"<image>\n{user_input}"
47
  else:
48
  model_input_text = user_input
49
 
50
+ system_prompt = """
51
+ You are an expert, empathetic AI medical assistant... (your full prompt here)
52
+ ***Disclaimer: I am an AI assistant and not a medical professional...***
53
+ """
54
 
 
55
  conversation = [{"role": "system", "content": system_prompt}]
56
  for turn_input, assistant_output in history_state:
57
  conversation.append({"role": "user", "content": turn_input})
 
63
  conversation, tokenize=False, add_generation_prompt=True
64
  )
65
 
 
66
  try:
67
+ # Pass the image and text to the processor for encoding
68
  if current_image:
69
+ # FIX 2: Ensure the image is passed as a list. This is more robust.
70
+ inputs = processor(text=prompt, images=[current_image], return_tensors="pt").to(model.device, dtype)
71
  else:
72
  inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
73
 
74
+ # Generate the response
75
  outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
76
  input_token_len = inputs["input_ids"].shape[1]
77
  generated_tokens = outputs[:, input_token_len:]
78
  clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
79
 
80
  except Exception as e:
81
+ # FIX 1: EXPOSE THE REAL ERROR IN THE UI FOR DEBUGGING
82
+ # This is the most important change. We will now see the true error message.
83
+ print(f"Caught a critical exception during generation: {e}", flush=True)
84
+ clean_response = (
85
+ "An error occurred during generation. This is the technical details:\n\n"
86
+ f"```\n{type(e).__name__}: {e}\n```"
87
+ )
88
+
89
+ # Update history and return values
90
  history_state.append((model_input_text, clean_response))
 
 
91
  display_history = [(turn.replace("<image>\n", ""), resp) for turn, resp in history_state]
92
 
 
 
 
 
 
 
93
  return display_history, history_state, current_image, None, ""
94
 
95
+ # --- Gradio Interface (No changes needed here from the last version) ---
96
  with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
97
+ gr.Markdown(
98
+ """
99
+ # AI Symptom Checker powered by MedGemma
100
+ Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash). The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
101
+ """
102
+ )
103
 
 
104
  image_state = gr.State(None)
105
 
106
  chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
 
110
  image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
111
 
112
  with gr.Row():
113
+ text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
114
  submit_btn = gr.Button("Send", variant="primary", scale=1)
115
 
 
116
  def clear_all():
117
  return [], [], None, None, ""
118
 
 
123
  queue=False
124
  )
125
 
 
126
  def on_submit(user_input, history, new_image, persisted_image):
127
  return symptom_checker_chat(user_input, history, new_image, persisted_image)
128
 
 
129
  submit_btn.click(
130
  fn=on_submit,
131
  inputs=[text_box, chat_history, image_box, image_state],