import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
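# `spaces` is the Hugging Face Spaces ZeroGPU helper; on ZeroGPU hardware the
# GPU-bound handler is typically decorated with @spaces.GPU. The import is
# otherwise unused here (assumed deployment detail).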

# --- Initialize the Model Pipeline ---
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=os.environ.get("HF_TOKEN")
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")

# --- Core CONVERSATIONAL Logic ---
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """
    Manages a single conversation turn with corrected state-management logic.
    """
    if not model_loaded:
        history.append((user_input, "Error: The AI model is not loaded."))
        # *** THE FIX: Return history twice to update both UI and State ***
        return history, history, None

    try:
        system_prompt = (
            "You are an expert, empathetic AI medical assistant conducting a virtual consultation. "
            "Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. "
            "Do NOT provide a diagnosis or a list of possibilities right away. Ask only one or two focused questions per turn. "
            "If the user provides an image, your first step is to analyze it from an expert perspective. Briefly describe the key findings from the image. "
            "Then, use this analysis to ask relevant follow-up questions about the user's symptoms or medical history to better understand the context. "
            "For example, after seeing a rash, you might say, 'I see a reddish rash with well-defined borders on the forearm. To help me understand more, could you tell me when you first noticed this? Is it itchy, painful, or does it have any other sensation?' "
            "After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. "
            "THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
        )
        generation_args = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}
        ai_response = ""

        if user_image:
            print("Image detected. Using multimodal 'messages' format...")
            # Rebuild the chat as a structured message list, then attach the
            # image to the latest user turn.
            messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
            for user_msg, assistant_msg in history:
                messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
                messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
            latest_user_content = [{"type": "text", "text": user_input}, {"type": "image", "image": user_image}]
            messages.append({"role": "user", "content": latest_user_content})
            output = pipe(text=messages, **generation_args)
            # The pipeline returns the whole conversation; the new reply is
            # the content of the final message.
            ai_response = output[0]["generated_text"][-1]["content"]
        else:
            print("No image detected. Using robust 'text-only' format.")
            # Hand-build a Gemma-style prompt. Each turn is closed with
            # <end_of_turn> so the model sees well-formed turns.
            prompt_parts = [f"<start_of_turn>system\n{system_prompt}<end_of_turn>"]
            for user_msg, assistant_msg in history:
                prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>")
                prompt_parts.append(f"<start_of_turn>model\n{assistant_msg}<end_of_turn>")
            prompt_parts.append(f"<start_of_turn>user\n{user_input}<end_of_turn>")
            prompt_parts.append("<start_of_turn>model")
            prompt = "\n".join(prompt_parts)
            output = pipe(prompt, **generation_args)
            full_text = output[0]["generated_text"]
            # The generated text echoes the prompt; keep only what follows the
            # final model cue and drop any trailing end-of-turn marker.
            ai_response = full_text.split("<start_of_turn>model")[-1].replace("<end_of_turn>", "").strip()

        history.append((user_input, ai_response))
        # *** THE FIX: Return history twice to update both UI and State ***
        return history, history, None

    except Exception as e:
        history.append((user_input, f"An error occurred: {str(e)}"))
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        # *** THE FIX: Return history twice to update both UI and State ***
        return history, history, None

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation") as demo:
    # The gr.State object that holds the conversation history.
    conversation_history = gr.State([])

    gr.HTML("""
        <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
            <h1>🩺 AI Symptom Consultation</h1>
            <p>A conversational AI to help you understand your symptoms, powered by Google's MedGemma</p>
        </div>
    """)
    gr.HTML("""
        <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
            <strong>⚠️ Medical Disclaimer:</strong> This is not a diagnosis. This AI is for informational purposes and is not a substitute for professional medical advice.
        </div>
    """)

    chatbot_display = gr.Chatbot(height=500, label="Consultation")

    with gr.Row():
        image_input = gr.Image(label="Upload Symptom Image (Optional)", type="pil", height=150)
        with gr.Column(scale=4):
            user_textbox = gr.Textbox(label="Your Message", placeholder="Describe your primary symptom to begin...", lines=4)
            send_button = gr.Button("Send Message", variant="primary")

    # The submit action now has `conversation_history` as both an input and an output.
    send_button.click(
        fn=handle_conversation_turn,
        inputs=[user_textbox, image_input, conversation_history],
        # *** THE FIX: Add `conversation_history` to the outputs list ***
        outputs=[chatbot_display, conversation_history, image_input]
    ).then(
        lambda: "", outputs=user_textbox
    )

    clear_button = gr.Button("🗑️ Start New Consultation")
    clear_button.click(
        lambda: ([], [], None, ""),
        outputs=[chatbot_display, conversation_history, image_input, user_textbox]
    )

if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)
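
# Note: device_map="auto" relies on the `accelerate` package being installed.
# A minimal requirements.txt for this Space would likely include: gradio,
# transformers, torch, accelerate, and pillow (exact pins are an assumption).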