# app.py — Hugging Face Space by hponepyae (commit 77793f4).
# (Web-page chrome from the Hub file view — "Update app.py", "raw", "history blame",
# "8.61 kB" — removed so the module parses.)
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
# --- Configuration & Model Loading ---
# Load MedGemma through the high-level pipeline API (more robust than manual
# model/processor wiring, per the working reference app).
print("Loading MedGemma model via pipeline...")
try:
    pipe = pipeline(
        # NOTE(review): MedGemma is a multimodal (image+text -> text) model;
        # confirm this task should not be "image-text-to-text" for
        # google/medgemma-4b-it — "image-to-text" may reject chat messages.
        "image-to-text",
        model="google/medgemma-4b-it",
        model_kwargs={"torch_dtype": torch.bfloat16},  # bf16 to fit GPU memory
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),  # gated model: needs an HF access token
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    # Boundary catch: keep the Space importable even if the download/load fails;
    # the chat handler surfaces the failure to the user instead of crashing here.
    model_loaded = False
    print(f"Error loading model: {e}")
# --- Core Chatbot Function ---
@spaces.GPU
def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
    """Run one turn of the symptom-checker conversation.

    Args:
        user_input: The user's new message text.
        history_for_display: List of (user_text, assistant_text) tuples shown in the UI.
        new_image_upload: PIL image freshly uploaded this turn, or None.
        image_state: PIL image persisted from an earlier turn, or None.

    Returns:
        Tuple of (updated history, image to persist, None to clear the image
        upload box, "" to clear the text box).
    """
    if not model_loaded:
        if user_input:
            history_for_display.append((user_input, "Error: The model could not be loaded."))
        return history_for_display, image_state, None, ""

    # Prefer a freshly uploaded image; otherwise keep the one from earlier turns.
    current_image = new_image_upload if new_image_upload is not None else image_state

    # Build the chat-format 'messages' list the pipeline expects.
    messages = []
    for user_msg, assistant_msg in history_for_display:
        # Historical turns are text-only; the image (if any) belongs to the
        # first user turn and is passed to the pipeline separately below.
        messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})

    # Current user turn: attach an image placeholder only on the very first
    # message of an image conversation (the pipeline receives the actual
    # image object as a separate argument).
    current_user_content = [{"type": "text", "text": user_input}]
    if current_image is not None and not history_for_display:
        current_user_content.append({"type": "image"})
    messages.append({"role": "user", "content": current_user_content})

    try:
        # Explicit None-check: PIL images should not be truth-tested.
        if current_image is not None:
            output = pipe(current_image, prompt=messages, generate_kwargs={"max_new_tokens": 512})
        else:
            output = pipe(prompt=messages, generate_kwargs={"max_new_tokens": 512})

        # The pipeline may return either a plain string or a chat-style list of
        # {"type": ..., "text": ...} dicts; extract the last text chunk.
        result = output[0]["generated_text"]
        if isinstance(result, list):
            # .get() guards against content items that lack a "type" key.
            clean_response = next(
                (item["text"] for item in reversed(result) if item.get("type") == "text"),
                "Sorry, I couldn't generate a response.",
            )
        else:
            clean_response = result
    except Exception as e:
        # Surface the real error in the chat window to ease debugging on Spaces.
        print(f"Caught a critical exception during generation: {e}", flush=True)
        clean_response = (
            "An error occurred during generation. Details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    history_for_display.append((user_input, clean_response))
    return history_for_display, current_image, None, ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
# AI Symptom Checker powered by MedGemma
Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image. The AI will analyze the inputs and ask clarifying questions if needed.
"""
    )
    # Image persisted across the whole conversation.
    image_state = gr.State(None)
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
    # Display history: list of (user_text, assistant_text) tuples.
    chat_history = gr.State([])
    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
    with gr.Row():
        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    def clear_all():
        # Reset chatbot display, history state, image state, image box, text box.
        return [], [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(
        fn=clear_all,
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
        queue=False,
    )

    def on_submit(user_input, display_history, new_image, persisted_image):
        """Validate the submission, run a chat turn, and fan the history out
        to both the Chatbot display and the history state."""
        if not user_input.strip() and not new_image:
            return display_history, display_history, persisted_image, None, ""
        history, image, image_box_val, text_val = symptom_checker_chat(
            user_input, display_history, new_image, persisted_image
        )
        # BUG FIX: the chatbot component was never included in any event's
        # outputs, so the conversation never rendered; feed it the history too.
        return history, history, image, image_box_val, text_val

    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chat_history, image_box, image_state],
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chat_history, image_box, image_state],
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
# NOTE(review): removed unreachable leftovers from an earlier revision that had
# been pasted after demo.launch():
#   1. An orphaned generation fragment (`model.generate(...)` through a bare
#      `except Exception` with no matching `try`) — a syntax error that
#      referenced undefined names (`model`, `inputs`, `processor`).
#   2. A full duplicate of the Gradio interface, which also referenced avatar
#      images ("user.png", "bot.png") not present in the repository.
# The interface defined above is the single canonical UI.