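"""AI Doctor Consultation: a Gradio Space that wraps Google's MedGemma
image-text-to-text model in a multi-turn, question-first symptom consultation."""
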
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
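
# NOTE (assumption; dependencies are not pinned in this file): this Space
# expects gradio, transformers, torch, accelerate, and Pillow at runtime;
# the `spaces` package supplies the @spaces.GPU() decorator used on ZeroGPU
# hardware.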

# --- Initialize the Model Pipeline ---
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),  # MedGemma is gated; the token must have been granted access
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")

# --- Core Conversational Logic ---
@spaces.GPU()
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """
    Run a single turn of the consultation using a role-specific system prompt.
    Returns the updated chat history plus None, which clears the image input.
    """
    if not model_loaded:
        history.append((user_input, "Error: The AI model is not loaded. Please contact the administrator."))
        return history, None
    try:
        # A role-specific system prompt that steers the model toward a staged,
        # question-first consultation rather than an immediate diagnosis.
        system_prompt = (
            "You are an expert, empathetic AI medical assistant conducting a virtual consultation. "
            "Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. "
            "Do NOT provide a diagnosis or a list of possibilities right away. Ask only one or two focused questions per turn. "
            # Image handling: describe the findings first, then ask targeted follow-ups.
            "If the user provides an image, your first step is to analyze it from an expert perspective. Briefly describe the key findings from the image. "
            "Then, use this analysis to ask relevant follow-up questions about the user's symptoms or medical history to better understand the context. "
            "For example, after seeing a rash, you might say, 'I see a reddish rash with well-defined borders on the forearm. To help me understand more, could you tell me when you first noticed this? Is it itchy, painful, or does it have any other sensation?' "
            # Wrap-up: announce the summary, then deliver it in the same response.
            "After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. "
            "THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
        )
        generation_args = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}
        ai_response = ""
        # Two paths: the multimodal "messages" format when an image is attached,
        # and a plain Gemma-style text prompt otherwise.
        if user_image:
            print("Image detected. Using multimodal 'messages' format...")
            messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
            for user_msg, assistant_msg in history:
                messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
                messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
            # The latest turn always carries the image, plus any accompanying text.
            latest_user_content = []
            if user_input:
                latest_user_content.append({"type": "text", "text": user_input})
            latest_user_content.append({"type": "image", "image": user_image})
            messages.append({"role": "user", "content": latest_user_content})
            output = pipe(text=messages, **generation_args)
            # The pipeline echoes the conversation; the last message is the model's reply.
            ai_response = output[0]["generated_text"][-1]["content"]
        else:
            print("No image detected. Using Gemma-style text prompt format...")
            # Gemma turn markers: each turn should be closed with <end_of_turn>
            # before the next <start_of_turn> opens.
            prompt_parts = [f"<start_of_turn>system\n{system_prompt}<end_of_turn>\n"]
            for user_msg, assistant_msg in history:
                prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>\n")
                prompt_parts.append(f"<start_of_turn>model\n{assistant_msg}<end_of_turn>\n")
            prompt_parts.append(f"<start_of_turn>user\n{user_input}<end_of_turn>\n<start_of_turn>model\n")
            prompt = "".join(prompt_parts)
            output = pipe(prompt, **generation_args)
            full_text = output[0]["generated_text"]
            # Keep only the text generated after the final model marker.
            ai_response = full_text.split("<start_of_turn>model")[-1].strip()
        history.append((user_input, ai_response))
        return history, None
    except Exception as e:
        history.append((user_input, f"An error occurred: {str(e)}"))
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        return history, None

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation") as demo:
    conversation_history = gr.State([])
    gr.HTML("""
        <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
            <h1>🩺 AI Symptom Consultation</h1>
            <p>A conversational AI to help you understand your symptoms, powered by Google's MedGemma</p>
        </div>
    """)
    gr.HTML("""
        <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
            <strong>⚠️ Medical Disclaimer:</strong> This is not a diagnosis. This AI is for informational purposes and is not a substitute for professional medical advice.
        </div>
    """)
    chatbot_display = gr.Chatbot(height=500, label="Consultation")
    with gr.Row():
        image_input = gr.Image(label="Upload Symptom Image (Optional)", type="pil", height=150)
        with gr.Column(scale=4):
            user_textbox = gr.Textbox(label="Your Message", placeholder="Describe your primary symptom to begin...", lines=4)
            send_button = gr.Button("Send Message", variant="primary")

    def submit_message(user_input, user_image, history):
        # Run one turn, then push the new history to both the chatbot display
        # and the session state, and clear the image input.
        updated_history, cleared_image = handle_conversation_turn(user_input, user_image, history)
        return updated_history, updated_history, cleared_image
    # Submit: update the chat display AND the history state, clear the image...
    send_button.click(
        fn=submit_message,
        inputs=[user_textbox, image_input, conversation_history],
        outputs=[chatbot_display, conversation_history, image_input],
    ).then(
        # ...then clear the user's text box once the message is sent.
        lambda: "",
        outputs=user_textbox,
    )
    # A clear button to start a fresh consultation.
    clear_button = gr.Button("🗑️ Start New Consultation")
    clear_button.click(
        lambda: ([], [], None, ""),
        outputs=[chatbot_display, conversation_history, image_input, user_textbox],
    )

if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)
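
# Local run sketch (assumption; on Hugging Face Spaces the launch is managed
# for you): set HF_TOKEN to a token with MedGemma access, run `python app.py`,
# then open the printed local URL.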