Spaces:
Running
Running
Update app.py
Browse files
app.py
CHANGED
@@ -8,18 +8,19 @@ import spaces
|
|
8 |
# --- Configuration & Model Loading ---
|
9 |
# Use the pipeline, which is more robust as seen in the working example
|
10 |
print("Loading MedGemma model via pipeline...")
|
|
|
|
|
11 |
try:
|
12 |
pipe = pipeline(
|
13 |
-
"image-to-text",
|
14 |
model="google/medgemma-4b-it",
|
15 |
-
model_kwargs={"torch_dtype": torch.bfloat16},
|
16 |
device_map="auto",
|
17 |
token=os.environ.get("HF_TOKEN")
|
18 |
)
|
19 |
model_loaded = True
|
20 |
print("Model loaded successfully!")
|
21 |
except Exception as e:
|
22 |
-
model_loaded = False
|
23 |
print(f"Error loading model: {e}")
|
24 |
|
25 |
|
@@ -36,45 +37,31 @@ def symptom_checker_chat(user_input, history_for_display, new_image_upload, imag
|
|
36 |
|
37 |
current_image = new_image_upload if new_image_upload is not None else image_state
|
38 |
|
39 |
-
#
|
40 |
-
# Build the 'messages' list using the exact format from the working X-ray app.
|
41 |
messages = []
|
42 |
|
43 |
-
# Optional: System prompt can be added here if needed, following the same format.
|
44 |
-
|
45 |
# Process the conversation history
|
46 |
for user_msg, assistant_msg in history_for_display:
|
47 |
-
|
48 |
-
# So, all historical messages are just text.
|
49 |
-
messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
|
50 |
if assistant_msg:
|
51 |
-
messages.append({"role": "assistant", "content":
|
52 |
|
53 |
# Add the current user turn
|
54 |
-
|
55 |
-
# If there's an image for the conversation, add it to the first user turn's content
|
56 |
-
if current_image is not None and not history_for_display: # Only for the very first message
|
57 |
-
current_user_content.append({"type": "image"}) # The pipeline handles the image object separately
|
58 |
-
|
59 |
-
messages.append({"role": "user", "content": current_user_content})
|
60 |
|
61 |
try:
|
62 |
-
#
|
63 |
-
#
|
64 |
if current_image:
|
65 |
-
|
|
|
66 |
else:
|
67 |
# If no image, the pipeline can work with just the prompt
|
68 |
-
output = pipe(prompt=messages, generate_kwargs={"max_new_tokens": 512})
|
69 |
|
70 |
-
# The pipeline
|
71 |
-
#
|
72 |
-
|
73 |
-
if isinstance(result, list):
|
74 |
-
# Find the last text content from the model's response
|
75 |
-
clean_response = next((item['text'] for item in reversed(result) if item['type'] == 'text'), "Sorry, I couldn't generate a response.")
|
76 |
-
else: # Simpler text-only output
|
77 |
-
clean_response = result
|
78 |
|
79 |
except Exception as e:
|
80 |
print(f"Caught a critical exception during generation: {e}", flush=True)
|
@@ -86,73 +73,6 @@ def symptom_checker_chat(user_input, history_for_display, new_image_upload, imag
|
|
86 |
history_for_display.append((user_input, clean_response))
|
87 |
return history_for_display, current_image, None, ""
|
88 |
|
89 |
-
# --- Gradio Interface (Mostly unchanged) ---
|
90 |
-
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
|
91 |
-
gr.Markdown(
|
92 |
-
"""
|
93 |
-
# AI Symptom Checker powered by MedGemma
|
94 |
-
Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image. The AI will analyze the inputs and ask clarifying questions if needed.
|
95 |
-
"""
|
96 |
-
)
|
97 |
-
|
98 |
-
image_state = gr.State(None)
|
99 |
-
chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
|
100 |
-
chat_history = gr.State([])
|
101 |
-
|
102 |
-
with gr.Row():
|
103 |
-
image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
|
104 |
-
|
105 |
-
with gr.Row():
|
106 |
-
text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
|
107 |
-
submit_btn = gr.Button("Send", variant="primary", scale=1)
|
108 |
-
|
109 |
-
def clear_all():
|
110 |
-
return [], None, None, ""
|
111 |
-
|
112 |
-
clear_btn = gr.Button("Start New Conversation")
|
113 |
-
clear_btn.click(fn=clear_all, outputs=[chat_history, image_state, image_box, text_box], queue=False)
|
114 |
-
|
115 |
-
def on_submit(user_input, display_history, new_image, persisted_image):
|
116 |
-
if not user_input.strip() and not new_image:
|
117 |
-
return display_history, persisted_image, None, ""
|
118 |
-
# The display history IS our history state now
|
119 |
-
return symptom_checker_chat(user_input, display_history, new_image, persisted_image)
|
120 |
-
|
121 |
-
submit_btn.click(
|
122 |
-
fn=on_submit,
|
123 |
-
inputs=[text_box, chat_history, image_box, image_state],
|
124 |
-
outputs=[chat_history, image_state, image_box, text_box]
|
125 |
-
)
|
126 |
-
text_box.submit(
|
127 |
-
fn=on_submit,
|
128 |
-
inputs=[text_box, chat_history, image_box, image_state],
|
129 |
-
outputs=[chat_history, image_state, image_box, text_box]
|
130 |
-
)
|
131 |
-
|
132 |
-
if __name__ == "__main__":
|
133 |
-
demo.launch(debug=True)
|
134 |
-
# Generate the response
|
135 |
-
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
|
136 |
-
|
137 |
-
# Decode only the newly generated part
|
138 |
-
input_token_len = inputs["input_ids"].shape[1]
|
139 |
-
generated_tokens = outputs[:, input_token_len:]
|
140 |
-
clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
|
141 |
-
|
142 |
-
except Exception as e:
|
143 |
-
print(f"Caught a critical exception during generation: {e}", flush=True)
|
144 |
-
# Display the real error in the UI for easier debugging
|
145 |
-
clean_response = (
|
146 |
-
"An error occurred during generation. This is the technical details:\n\n"
|
147 |
-
f"```\n{type(e).__name__}: {e}\n```"
|
148 |
-
)
|
149 |
-
|
150 |
-
# Update the display history
|
151 |
-
history_for_display.append((user_input, clean_response))
|
152 |
-
|
153 |
-
# Return all updated values
|
154 |
-
return history_for_display, current_image, None, ""
|
155 |
-
|
156 |
# --- Gradio Interface ---
|
157 |
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
|
158 |
gr.Markdown(
|
@@ -162,44 +82,34 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
|
|
162 |
"""
|
163 |
)
|
164 |
|
165 |
-
# State to hold the image across an entire conversation
|
166 |
image_state = gr.State(None)
|
167 |
-
|
168 |
chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False, avatar_images=("user.png", "bot.png"))
|
169 |
-
# The history state will now just be for display, a simple list of (text, text) tuples.
|
170 |
chat_history = gr.State([])
|
171 |
|
172 |
with gr.Row():
|
173 |
image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
|
174 |
|
175 |
with gr.Row():
|
176 |
-
text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
|
177 |
submit_btn = gr.Button("Send", variant="primary", scale=1)
|
178 |
|
179 |
-
# The clear function now resets all three states
|
180 |
def clear_all():
|
181 |
return [], None, None, ""
|
182 |
|
183 |
clear_btn = gr.Button("Start New Conversation")
|
184 |
-
clear_btn.click(
|
185 |
-
fn=clear_all,
|
186 |
-
outputs=[chat_history, image_state, image_box, text_box],
|
187 |
-
queue=False
|
188 |
-
)
|
189 |
|
190 |
-
# The submit handler function
|
191 |
def on_submit(user_input, display_history, new_image, persisted_image):
|
192 |
if not user_input.strip() and not new_image:
|
193 |
return display_history, persisted_image, None, ""
|
194 |
return symptom_checker_chat(user_input, display_history, new_image, persisted_image)
|
195 |
|
196 |
-
#
|
197 |
submit_btn.click(
|
198 |
fn=on_submit,
|
199 |
inputs=[text_box, chat_history, image_box, image_state],
|
200 |
outputs=[chat_history, image_state, image_box, text_box]
|
201 |
)
|
202 |
-
|
203 |
text_box.submit(
|
204 |
fn=on_submit,
|
205 |
inputs=[text_box, chat_history, image_box, image_state],
|
@@ -208,3 +118,4 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
|
|
208 |
|
209 |
if __name__ == "__main__":
|
210 |
demo.launch(debug=True)
|
|
|
|
8 |
# --- Configuration & Model Loading ---
|
9 |
# Use the pipeline, which is more robust as seen in the working example
|
10 |
print("Loading MedGemma model via pipeline...")
|
11 |
+
model_loaded = False
|
12 |
+
pipe = None
|
13 |
try:
|
14 |
pipe = pipeline(
|
15 |
+
"image-to-text",
|
16 |
model="google/medgemma-4b-it",
|
17 |
+
model_kwargs={"torch_dtype": torch.bfloat16},
|
18 |
device_map="auto",
|
19 |
token=os.environ.get("HF_TOKEN")
|
20 |
)
|
21 |
model_loaded = True
|
22 |
print("Model loaded successfully!")
|
23 |
except Exception as e:
|
|
|
24 |
print(f"Error loading model: {e}")
|
25 |
|
26 |
|
|
|
37 |
|
38 |
current_image = new_image_upload if new_image_upload is not None else image_state
|
39 |
|
40 |
+
# Build the 'messages' list using the correct format for the pipeline
|
|
|
41 |
messages = []
|
42 |
|
|
|
|
|
43 |
# Process the conversation history
|
44 |
for user_msg, assistant_msg in history_for_display:
|
45 |
+
messages.append({"role": "user", "content": user_msg})
|
|
|
|
|
46 |
if assistant_msg:
|
47 |
+
messages.append({"role": "assistant", "content": assistant_msg})
|
48 |
|
49 |
# Add the current user turn
|
50 |
+
messages.append({"role": "user", "content": user_input})
|
|
|
|
|
|
|
|
|
|
|
51 |
|
52 |
try:
|
53 |
+
# The pipeline call is simpler. We pass the image as the main argument
|
54 |
+
# and the text conversation as the `prompt`.
|
55 |
if current_image:
|
56 |
+
# The image goes first, the prompt kwarg contains the conversation history
|
57 |
+
output = pipe(current_image, prompt=messages, generate_kwargs={"max_new_tokens": 512, "do_sample": True, "temperature": 0.7})
|
58 |
else:
|
59 |
# If no image, the pipeline can work with just the prompt
|
60 |
+
output = pipe(prompt=messages, generate_kwargs={"max_new_tokens": 512, "do_sample": True, "temperature": 0.7})
|
61 |
|
62 |
+
# The pipeline output structure contains the full conversation.
|
63 |
+
# We want the content of the last message, which is the model's reply.
|
64 |
+
clean_response = output[0]["generated_text"][-1]['content']
|
|
|
|
|
|
|
|
|
|
|
65 |
|
66 |
except Exception as e:
|
67 |
print(f"Caught a critical exception during generation: {e}", flush=True)
|
|
|
73 |
history_for_display.append((user_input, clean_response))
|
74 |
return history_for_display, current_image, None, ""
|
75 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
76 |
# --- Gradio Interface ---
|
77 |
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
|
78 |
gr.Markdown(
|
|
|
82 |
"""
|
83 |
)
|
84 |
|
|
|
85 |
image_state = gr.State(None)
|
|
|
86 |
chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False, avatar_images=("user.png", "bot.png"))
|
|
|
87 |
chat_history = gr.State([])
|
88 |
|
89 |
with gr.Row():
|
90 |
image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
|
91 |
|
92 |
with gr.Row():
|
93 |
+
text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm that is red and itchy...", scale=4)
|
94 |
submit_btn = gr.Button("Send", variant="primary", scale=1)
|
95 |
|
|
|
96 |
def clear_all():
|
97 |
return [], None, None, ""
|
98 |
|
99 |
clear_btn = gr.Button("Start New Conversation")
|
100 |
+
clear_btn.click(fn=clear_all, outputs=[chat_history, image_state, image_box, text_box], queue=False)
|
|
|
|
|
|
|
|
|
101 |
|
|
|
102 |
def on_submit(user_input, display_history, new_image, persisted_image):
|
103 |
if not user_input.strip() and not new_image:
|
104 |
return display_history, persisted_image, None, ""
|
105 |
return symptom_checker_chat(user_input, display_history, new_image, persisted_image)
|
106 |
|
107 |
+
# Event Handlers for submit button and enter key
|
108 |
submit_btn.click(
|
109 |
fn=on_submit,
|
110 |
inputs=[text_box, chat_history, image_box, image_state],
|
111 |
outputs=[chat_history, image_state, image_box, text_box]
|
112 |
)
|
|
|
113 |
text_box.submit(
|
114 |
fn=on_submit,
|
115 |
inputs=[text_box, chat_history, image_box, image_state],
|
|
|
118 |
|
119 |
if __name__ == "__main__":
|
120 |
demo.launch(debug=True)
|
121 |
+
|