hponepyae committed
Commit b47c12e · verified · 1 Parent(s): 46668b2

Update app.py

Files changed (1): app.py (+47 -29)
app.py CHANGED
@@ -33,8 +33,8 @@ except Exception as e:
 @spaces.GPU
 def symptom_checker_chat(user_input, history_state, new_image_upload, image_state):
     """
-    Manages the conversational flow, persisting the image across turns.
-    Includes robust error reporting directly in the UI for debugging.
+    Manages the conversational flow by manually building the prompt to ensure
+    correct handling of the <image> token.
     """
     if not model_loaded:
         history_state.append((user_input, "Error: The model could not be loaded."))
@@ -42,34 +42,41 @@ def symptom_checker_chat(user_input, history_state, new_image_upload, image_stat
 
     current_image = new_image_upload if new_image_upload is not None else image_state
 
-    if new_image_upload is not None:
-        model_input_text = f"<image>\n{user_input}"
-    else:
-        model_input_text = user_input
-
-    system_prompt = """
-    You are an expert, empathetic AI medical assistant... (your full prompt here)
-    ***Disclaimer: I am an AI assistant and not a medical professional...***
-    """
-
-    conversation = [{"role": "system", "content": system_prompt}]
+    # --- FIX: Manual Prompt Construction ---
+    # This gives us full control and bypasses the opaque apply_chat_template behavior.
+
+    # The system prompt is not one of the turns; it is prepended as a prefix.
+    system_prompt = "You are an expert, empathetic AI medical assistant..."  # keep your full system prompt
+
+    # Build the prompt from history
+    prompt_parts = []
     for turn_input, assistant_output in history_state:
-        conversation.append({"role": "user", "content": turn_input})
+        # Add a user turn from history
+        prompt_parts.append(f"<start_of_turn>user\n{turn_input}<end_of_turn>\n")
+        # Add a model turn from history
         if assistant_output:
-            conversation.append({"role": "assistant", "content": assistant_output})
-    conversation.append({"role": "user", "content": model_input_text})
-
-    prompt = processor.apply_chat_template(
-        conversation, tokenize=False, add_generation_prompt=True
-    )
+            prompt_parts.append(f"<start_of_turn>model\n{assistant_output}<end_of_turn>\n")
+
+    # Add the current user turn
+    prompt_parts.append("<start_of_turn>user\n")
+    # The most important part: add the <image> token if an image is present,
+    # whether it is a new upload or one persisted from earlier in the conversation.
+    if current_image:
+        prompt_parts.append("<image>\n")
+    prompt_parts.append(f"{user_input}<end_of_turn>\n")
+
+    # Add the generation prompt so the model starts its response
+    prompt_parts.append("<start_of_turn>model\n")
+
+    # Join everything into a single string, prefixed by the system prompt
+    final_prompt = system_prompt + "\n" + "".join(prompt_parts)
 
     try:
-        # Pass the image and text to the processor for encoding
+        # Process the inputs using the manually built prompt
         if current_image:
-            # FIX 2: Ensure the image is passed as a list. This is more robust.
-            inputs = processor(text=prompt, images=[current_image], return_tensors="pt").to(model.device, dtype)
+            inputs = processor(text=final_prompt, images=[current_image], return_tensors="pt").to(model.device, dtype)
         else:
-            inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
+            inputs = processor(text=final_prompt, return_tensors="pt").to(model.device, dtype)
 
         # Generate the response
         outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
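
The core of this commit is the hunk above: replacing `apply_chat_template` with a hand-built, Gemma-style turn string. A minimal standalone sketch of the same construction (the `build_prompt` helper and the sample history are illustrative, not part of the commit; the system-prompt prefix and `<image>` placeholder follow the commit's own comments):

```python
# Illustrative sketch of the manual prompt construction, outside the app.
def build_prompt(system_prompt, history, user_input, has_image):
    parts = []
    # Replay past turns in the <start_of_turn>role ... <end_of_turn> format.
    for turn_input, assistant_output in history:
        parts.append(f"<start_of_turn>user\n{turn_input}<end_of_turn>\n")
        if assistant_output:
            parts.append(f"<start_of_turn>model\n{assistant_output}<end_of_turn>\n")
    # Current user turn; the <image> placeholder is expanded by the processor.
    parts.append("<start_of_turn>user\n")
    if has_image:
        parts.append("<image>\n")
    parts.append(f"{user_input}<end_of_turn>\n")
    # Generation prompt: the model speaks next.
    parts.append("<start_of_turn>model\n")
    return system_prompt + "\n" + "".join(parts)

if __name__ == "__main__":
    print(build_prompt(
        "You are a medical assistant.",
        [("I have a rash.", "Can you describe it?")],
        "Here is a photo.",
        has_image=True,
    ))
```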
@@ -78,21 +85,29 @@ def symptom_checker_chat(user_input, history_state, new_image_upload, image_stat
         clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
 
     except Exception as e:
-        # FIX 1: EXPOSE THE REAL ERROR IN THE UI FOR DEBUGGING
-        # This is the most important change. We will now see the true error message.
         print(f"Caught a critical exception during generation: {e}", flush=True)
+        # Display the real error in the UI for easier debugging
         clean_response = (
             "An error occurred during generation. Here are the technical details:\n\n"
             f"```\n{type(e).__name__}: {e}\n```"
         )
 
-    # Update history and return values
-    history_state.append((model_input_text, clean_response))
+    # --- History Management ---
+    # Save the user input to history with a marker recording that an image was
+    # present; the same <image>\n token serves as that marker.
+    history_input = user_input
+    if current_image:
+        history_input = f"<image>\n{user_input}"
+
+    history_state.append((history_input, clean_response))
+
+    # Create the display history without the special tokens
     display_history = [(turn.replace("<image>\n", ""), resp) for turn, resp in history_state]
 
+    # Return all updated values
     return display_history, history_state, current_image, None, ""
 
-# --- Gradio Interface (No changes needed here from the last version) ---
+# --- Gradio Interface ---
 with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
     gr.Markdown(
         """
@@ -124,6 +139,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
     )
 
     def on_submit(user_input, history, new_image, persisted_image):
+        # We need to handle the case where the user input is empty
+        if not user_input.strip():
+            return history, history, persisted_image, None, ""
         return symptom_checker_chat(user_input, history, new_image, persisted_image)
 
     submit_btn.click(
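
The guard added in this hunk returns five values because `symptom_checker_chat` does: display history, raw history state, persisted image, a cleared upload widget, and a cleared textbox. A runnable sketch of how that wiring might look, with a stub in place of the real chat function; every component name here is an assumption, since the definitions sit outside the lines shown in this diff:

```python
# Hypothetical wiring sketch; component names are assumptions, not the app's.
import gradio as gr

def stub_chat(user_input, history, new_image, persisted_image):
    # Stand-in for symptom_checker_chat: echo the input, persist the image.
    current_image = new_image if new_image is not None else persisted_image
    history = history + [(user_input, f"(stub reply to: {user_input})")]
    return history, history, current_image, None, ""

def on_submit(user_input, history, new_image, persisted_image):
    if not user_input.strip():
        # Empty input: leave state unchanged, clear the upload and textbox.
        return history, history, persisted_image, None, ""
    return stub_chat(user_input, history, new_image, persisted_image)

with gr.Blocks() as demo:
    chatbot = gr.Chatbot()
    text_input = gr.Textbox()
    image_upload = gr.Image(type="pil")
    history_st = gr.State([])
    image_st = gr.State(None)
    send = gr.Button("Send")
    send.click(
        fn=on_submit,
        inputs=[text_input, history_st, image_upload, image_st],
        # One output per return value: display history, raw history state,
        # persisted image, cleared upload widget, cleared textbox.
        outputs=[chatbot, history_st, image_st, image_upload, text_input],
    )

# demo.launch()
```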
 