Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -34,34 +34,24 @@ except Exception as e:
|
|
34 |
def symptom_checker_chat(user_input, history_state, new_image_upload, image_state):
|
35 |
"""
|
36 |
Manages the conversational flow, persisting the image across turns.
|
37 |
-
|
38 |
-
Args:
|
39 |
-
user_input (str): Text from the user.
|
40 |
-
history_state (list): Stateful conversation history.
|
41 |
-
new_image_upload (PIL.Image): A new image uploaded in this turn.
|
42 |
-
image_state (PIL.Image): The image persisted from a previous turn.
|
43 |
-
|
44 |
-
Returns:
|
45 |
-
tuple: Updated values for all relevant Gradio components.
|
46 |
"""
|
47 |
if not model_loaded:
|
48 |
history_state.append((user_input, "Error: The model could not be loaded."))
|
49 |
return history_state, history_state, None, None, ""
|
50 |
|
51 |
-
# FIX: Determine which image to use. A new upload takes precedence.
|
52 |
-
# This is the key to solving the "image amnesia" problem.
|
53 |
current_image = new_image_upload if new_image_upload is not None else image_state
|
54 |
|
55 |
-
# If this is the *first* turn with an image, add the <image> token.
|
56 |
-
# Don't add it again if the image is just being carried over in the state.
|
57 |
if new_image_upload is not None:
|
58 |
model_input_text = f"<image>\n{user_input}"
|
59 |
else:
|
60 |
model_input_text = user_input
|
61 |
|
62 |
-
system_prompt = """
|
|
|
|
|
|
|
63 |
|
64 |
-
# Construct the full conversation history
|
65 |
conversation = [{"role": "system", "content": system_prompt}]
|
66 |
for turn_input, assistant_output in history_state:
|
67 |
conversation.append({"role": "user", "content": turn_input})
|
@@ -73,41 +63,44 @@ def symptom_checker_chat(user_input, history_state, new_image_upload, image_stat
|
|
73 |
conversation, tokenize=False, add_generation_prompt=True
|
74 |
)
|
75 |
|
76 |
-
# Process inputs. Crucially, pass `current_image` to the processor.
|
77 |
try:
|
|
|
78 |
if current_image:
|
79 |
-
|
|
|
80 |
else:
|
81 |
inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
|
82 |
|
|
|
83 |
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
84 |
input_token_len = inputs["input_ids"].shape[1]
|
85 |
generated_tokens = outputs[:, input_token_len:]
|
86 |
clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
|
87 |
|
88 |
except Exception as e:
|
89 |
-
|
90 |
-
|
91 |
-
|
92 |
-
|
|
|
|
|
|
|
|
|
|
|
93 |
history_state.append((model_input_text, clean_response))
|
94 |
-
|
95 |
-
# Create the display history without the special tokens
|
96 |
display_history = [(turn.replace("<image>\n", ""), resp) for turn, resp in history_state]
|
97 |
|
98 |
-
# Return everything:
|
99 |
-
# 1. display_history -> to the chatbot UI
|
100 |
-
# 2. history_state -> to the text history state
|
101 |
-
# 3. current_image -> to the image state to be persisted
|
102 |
-
# 4. None -> to the image upload box to clear it
|
103 |
-
# 5. "" -> to the text box to clear it
|
104 |
return display_history, history_state, current_image, None, ""
|
105 |
|
106 |
-
# --- Gradio Interface ---
|
107 |
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
|
108 |
-
gr.Markdown(
|
|
|
|
|
|
|
|
|
|
|
109 |
|
110 |
-
# FIX: Add a new state component to hold the image across turns
|
111 |
image_state = gr.State(None)
|
112 |
|
113 |
chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
|
@@ -117,10 +110,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
|
|
117 |
image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
|
118 |
|
119 |
with gr.Row():
|
120 |
-
text_box = gr.Textbox(label="Describe your symptoms...", scale=4)
|
121 |
submit_btn = gr.Button("Send", variant="primary", scale=1)
|
122 |
|
123 |
-
# FIX: Update the clear function to also clear the new image_state
|
124 |
def clear_all():
|
125 |
return [], [], None, None, ""
|
126 |
|
@@ -131,11 +123,9 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
|
|
131 |
queue=False
|
132 |
)
|
133 |
|
134 |
-
# Combine submit actions into a single function for DRY principle
|
135 |
def on_submit(user_input, history, new_image, persisted_image):
|
136 |
return symptom_checker_chat(user_input, history, new_image, persisted_image)
|
137 |
|
138 |
-
# FIX: Update inputs and outputs to include the new image_state
|
139 |
submit_btn.click(
|
140 |
fn=on_submit,
|
141 |
inputs=[text_box, chat_history, image_box, image_state],
|
|
|
34 |
def symptom_checker_chat(user_input, history_state, new_image_upload, image_state):
|
35 |
"""
|
36 |
Manages the conversational flow, persisting the image across turns.
|
37 |
+
Includes robust error reporting directly in the UI for debugging.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
38 |
"""
|
39 |
if not model_loaded:
|
40 |
history_state.append((user_input, "Error: The model could not be loaded."))
|
41 |
return history_state, history_state, None, None, ""
|
42 |
|
|
|
|
|
43 |
current_image = new_image_upload if new_image_upload is not None else image_state
|
44 |
|
|
|
|
|
45 |
if new_image_upload is not None:
|
46 |
model_input_text = f"<image>\n{user_input}"
|
47 |
else:
|
48 |
model_input_text = user_input
|
49 |
|
50 |
+
system_prompt = """
|
51 |
+
You are an expert, empathetic AI medical assistant... (your full prompt here)
|
52 |
+
***Disclaimer: I am an AI assistant and not a medical professional...***
|
53 |
+
"""
|
54 |
|
|
|
55 |
conversation = [{"role": "system", "content": system_prompt}]
|
56 |
for turn_input, assistant_output in history_state:
|
57 |
conversation.append({"role": "user", "content": turn_input})
|
|
|
63 |
conversation, tokenize=False, add_generation_prompt=True
|
64 |
)
|
65 |
|
|
|
66 |
try:
|
67 |
+
# Pass the image and text to the processor for encoding
|
68 |
if current_image:
|
69 |
+
# FIX 2: Ensure the image is passed as a list. This is more robust.
|
70 |
+
inputs = processor(text=prompt, images=[current_image], return_tensors="pt").to(model.device, dtype)
|
71 |
else:
|
72 |
inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
|
73 |
|
74 |
+
# Generate the response
|
75 |
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
76 |
input_token_len = inputs["input_ids"].shape[1]
|
77 |
generated_tokens = outputs[:, input_token_len:]
|
78 |
clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
|
79 |
|
80 |
except Exception as e:
|
81 |
+
# FIX 1: EXPOSE THE REAL ERROR IN THE UI FOR DEBUGGING
|
82 |
+
# This is the most important change. We will now see the true error message.
|
83 |
+
print(f"Caught a critical exception during generation: {e}", flush=True)
|
84 |
+
clean_response = (
|
85 |
+
"An error occurred during generation. Here are the technical details:\n\n"
|
86 |
+
f"```\n{type(e).__name__}: {e}\n```"
|
87 |
+
)
|
88 |
+
|
89 |
+
# Update history and return values
|
90 |
history_state.append((model_input_text, clean_response))
|
|
|
|
|
91 |
display_history = [(turn.replace("<image>\n", ""), resp) for turn, resp in history_state]
|
92 |
|
|
|
|
|
|
|
|
|
|
|
|
|
93 |
return display_history, history_state, current_image, None, ""
|
94 |
|
95 |
+
# --- Gradio Interface (No changes needed here from the last version) ---
|
96 |
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
|
97 |
+
gr.Markdown(
|
98 |
+
"""
|
99 |
+
# AI Symptom Checker powered by MedGemma
|
100 |
+
Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash). The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
|
101 |
+
"""
|
102 |
+
)
|
103 |
|
|
|
104 |
image_state = gr.State(None)
|
105 |
|
106 |
chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
|
|
|
110 |
image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
|
111 |
|
112 |
with gr.Row():
|
113 |
+
text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
|
114 |
submit_btn = gr.Button("Send", variant="primary", scale=1)
|
115 |
|
|
|
116 |
def clear_all():
|
117 |
return [], [], None, None, ""
|
118 |
|
|
|
123 |
queue=False
|
124 |
)
|
125 |
|
|
|
126 |
def on_submit(user_input, history, new_image, persisted_image):
|
127 |
return symptom_checker_chat(user_input, history, new_image, persisted_image)
|
128 |
|
|
|
129 |
submit_btn.click(
|
130 |
fn=on_submit,
|
131 |
inputs=[text_box, chat_history, image_box, image_state],
|