# AI Symptom Checker — Hugging Face Space (MedGemma)
import os

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import pipeline
# --- Configuration & Model Loading ---
# Load MedGemma through the high-level pipeline API; the pipeline wires up the
# processor and model internally, which is more robust than manual loading.
print("Loading MedGemma model via pipeline...")
try:
    pipe = pipeline(
        "image-to-text",  # task that accepts an image plus a chat-style prompt
        model="google/medgemma-4b-it",
        model_kwargs={"torch_dtype": torch.bfloat16},  # reduced precision to fit GPU memory
        device_map="auto",  # let accelerate place layers on available devices
        token=os.environ.get("HF_TOKEN"),  # gated model: requires an HF access token
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    # Boundary handler: record the failure so the UI can report it instead of
    # crashing the Space at import time.
    model_loaded = False
    print(f"Error loading model: {e}")
# --- Core Chatbot Function ---
def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
    """
    Run one turn of the symptom-checker conversation.

    Args:
        user_input: Text the user just typed.
        history_for_display: List of (user_text, assistant_text) tuples shown
            in the Chatbot component; mutated in place and returned.
        new_image_upload: PIL image uploaded on this turn, or None.
        image_state: PIL image persisted from an earlier turn, or None.

    Returns:
        (updated_history, persisted_image, None, "") — the trailing None and ""
        clear the image box and the textbox in the UI.
    """
    if not model_loaded:
        if user_input:
            history_for_display.append((user_input, "Error: The model could not be loaded."))
        return history_for_display, image_state, None, ""

    # Prefer a freshly uploaded image; otherwise reuse the one persisted from
    # an earlier turn so the whole conversation can reference it.
    current_image = new_image_upload if new_image_upload is not None else image_state

    # Build the chat-style 'messages' list in the format the pipeline expects.
    messages = []
    for user_msg, assistant_msg in history_for_display:
        # Historical turns are text-only; the image (if any) belongs to the
        # first turn, which is handled below.
        messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})

    # Current user turn; attach an image placeholder only on the very first
    # message — the actual PIL image object is passed to the pipeline separately.
    current_user_content = [{"type": "text", "text": user_input}]
    if current_image is not None and not history_for_display:
        current_user_content.append({"type": "image"})
    messages.append({"role": "user", "content": current_user_content})

    try:
        # Explicit None check: PIL images should not be tested for truthiness.
        if current_image is not None:
            output = pipe(current_image, prompt=messages, generate_kwargs={"max_new_tokens": 512})
        else:
            output = pipe(prompt=messages, generate_kwargs={"max_new_tokens": 512})

        # The pipeline may return plain text or a list of content chunks;
        # extract the last text chunk in the latter case. Use .get so a chunk
        # without a "type" key cannot raise KeyError mid-extraction.
        result = output[0]["generated_text"]
        if isinstance(result, list):
            clean_response = next(
                (item["text"] for item in reversed(result) if item.get("type") == "text"),
                "Sorry, I couldn't generate a response.",
            )
        else:
            clean_response = result
    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        # Surface the real error in the chat window to ease debugging.
        clean_response = (
            "An error occurred during generation. Details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    history_for_display.append((user_input, clean_response))
    return history_for_display, current_image, None, ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image. The AI will analyze the inputs and ask clarifying questions if needed.
        """
    )
    # Image persisted across the whole conversation (first upload is reused).
    image_state = gr.State(None)
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)

    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
    with gr.Row():
        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    def clear_all():
        """Reset the chat display, persisted image, image box and textbox."""
        return [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    # BUG FIX: events previously read/wrote a hidden gr.State (`chat_history`),
    # so the visible Chatbot component never refreshed. The Chatbot itself now
    # serves as the history input and output.
    clear_btn.click(fn=clear_all, outputs=[chatbot, image_state, image_box, text_box], queue=False)

    def on_submit(user_input, display_history, new_image, persisted_image):
        """Ignore empty submissions, otherwise delegate to the chat function."""
        if not user_input.strip() and not new_image:
            return display_history, persisted_image, None, ""
        return symptom_checker_chat(user_input, display_history, new_image, persisted_image)

    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box],
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
# NOTE(review): removed an orphaned fragment left over from an earlier
# processor/model.generate implementation. It referenced undefined names
# (`model`, `processor`, `inputs`) and opened with an `except` clause that had
# no matching `try`, which is a SyntaxError at module level. The working
# generation path lives in symptom_checker_chat above.
# NOTE(review): removed a duplicated copy of the Gradio interface definition.
# It rebuilt `demo` (shadowing the first definition), re-wired the same events
# a second time, and referenced avatar image files ("user.png", "bot.png")
# that are not part of this Space. The single interface defined above is the
# one that runs.