import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
# --- Configuration & Model Loading ---
# Use the image-text-to-text pipeline, which is more robust, as seen in the
# working example: it accepts the chat-style `messages` format natively.
print("Loading MedGemma model via pipeline...")
model_loaded = False
pipe = None
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        token=os.environ.get("HF_TOKEN")
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
# --- Core Chatbot Function ---
# @spaces.GPU attaches ZeroGPU hardware for the duration of the call; a longer
# duration (120s here, adjust as needed) covers the slow first-time generation.
@spaces.GPU(duration=120)
def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
""" | |
Manages the conversation by embedding the image directly into the message structure, | |
which is the correct way to use this pipeline and prevents hanging. | |
""" | |
if not model_loaded: | |
if user_input: | |
history_for_display.append((user_input, "Error: The model could not be loaded.")) | |
return history_for_display, image_state, None, "" | |
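    # Prefer a freshly uploaded image; otherwise fall back to the image kept in
    # state, so follow-up questions retain the visual context.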
    current_image = new_image_upload if new_image_upload is not None else image_state
    # --- THE CORRECT IMPLEMENTATION ---
    # Build the `messages` list by embedding the image object directly inside the content.
    messages = []
    # Reconstruct the conversation from history.
    for i, (user_msg, assistant_msg) in enumerate(history_for_display):
        # Define the content for the user's turn.
        user_content = [{"type": "text", "text": user_msg}]
        # If it's the very first turn of the conversation AND an image exists
        # for it, embed the image object here.
        if i == 0 and current_image is not None:
            user_content.append({"type": "image", "image": current_image})
        messages.append({"role": "user", "content": user_content})
        if assistant_msg:
            # The assistant's response is always text.
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
    # Add the current user's input to the conversation.
    current_user_content = [{"type": "text", "text": user_input}]
    # If this is the start of a NEW conversation (no history) AND an image was
    # just uploaded, embed the image object in this first turn.
    if not history_for_display and current_image is not None:
        current_user_content.append({"type": "image", "image": current_image})
    messages.append({"role": "user", "content": current_user_content})
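    # For a first turn with an image, `messages` now looks roughly like:
    #   [{"role": "user", "content": [
    #       {"type": "text", "text": "I have a rash on my arm..."},
    #       {"type": "image", "image": <PIL.Image.Image>}]}]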
    try:
        # The pipeline call is simple: it only takes the `messages` structure,
        # which the pipeline unpacks internally.
        output = pipe(messages, max_new_tokens=512, do_sample=True, temperature=0.7)
        # The pipeline returns the full conversation; the last message is the model's reply.
        clean_response = output[0]["generated_text"][-1]["content"]
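        # Expected shape: output[0]["generated_text"] is the full message list
        # (the input turns plus a final {"role": "assistant", "content": "..."}),
        # hence taking [-1]["content"].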
    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        clean_response = (
            "An error occurred during generation. Details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )
    # Update history and return values for the Gradio UI.
    history_for_display.append((user_input, clean_response))
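    # Returning None for the image component clears the upload widget after each
    # send, while the uploaded image persists in image_state for follow-ups.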
    return history_for_display, current_image, None, ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image.
        """
    )
    image_state = gr.State(None)
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
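    # The Chatbot component's value is the displayed history (a list of
    # (user, assistant) tuples), so it is wired directly as both input and
    # output below; no separate gr.State copy of the history is needed.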
    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
    with gr.Row():
        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm that is red and itchy...", scale=4)
        submit_btn = gr.Button("Send", variant="primary", scale=1)
    def clear_all():
        return [], None, None, ""
    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(fn=clear_all, outputs=[chatbot, image_state, image_box, text_box], queue=False)
    def on_submit(user_input, display_history, new_image, persisted_image):
        if not user_input.strip() and not new_image:
            return display_history, persisted_image, None, ""
        return symptom_checker_chat(user_input, display_history, new_image, persisted_image)
    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box]
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box]
    )
if __name__ == "__main__":
    demo.launch(debug=True)