"""Conversational AI symptom-consultation demo built on Google's MedGemma.

A Gradio app that holds a multi-turn consultation: the model is prompted to
ask clarifying questions first and only later summarize possible conditions.
Supports an optional symptom image via the multimodal pipeline path.
"""

import os

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import pipeline

# --- Initialize the Model Pipeline ---
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    # Keep the app importable/launchable even when the model cannot load
    # (missing HF_TOKEN, no GPU, ...); the failure is surfaced to the user
    # on their first message instead of crashing at startup.
    model_loaded = False
    print(f"Error loading model: {e}")


# --- Core CONVERSATIONAL Logic ---
@spaces.GPU()
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """Run one turn of the consultation and append the exchange to history.

    Args:
        user_input: The user's message for this turn (may be empty when only
            an image is supplied).
        user_image: Optional PIL image of the symptom; selects the multimodal
            code path when present.
        history: List of ``(user_message, assistant_message)`` tuples for all
            previous turns. Mutated in place and also returned.

    Returns:
        Tuple of ``(history, None)`` — the updated history plus ``None`` to
        clear the image input component in the UI.
    """
    if not model_loaded:
        history.append(
            (user_input, "Error: The AI model is not loaded. Please contact the administrator.")
        )
        return history, None

    try:
        # Role-specific system prompt: ask focused questions first, analyze
        # any image up front, and only summarize once enough is gathered.
        system_prompt = (
            "You are an expert, empathetic AI medical assistant conducting a virtual consultation. "
            "Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. "
            "Do NOT provide a diagnosis or a list of possibilities right away. Ask only one or two focused questions per turn. "
            # Instruction for image handling:
            "If the user provides an image, your first step is to analyze it from an expert perspective. Briefly describe the key findings from the image. "
            "Then, use this analysis to ask relevant follow-up questions about the user's symptoms or medical history to better understand the context. "
            "For example, after seeing a rash, you might say, 'I see a reddish rash with well-defined borders on the forearm. To help me understand more, could you tell me when you first noticed this? Is it itchy, painful, or does it have any other sensation?' "
            # Instruction for the final summary step:
            "After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. "
            "THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
        )

        generation_args = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}
        ai_response = ""

        if user_image:
            # Multimodal path: use the chat "messages" format so the pipeline
            # applies the model's own chat template to text + image content.
            print("Image detected. Using multimodal 'messages' format...")
            messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
            for user_msg, assistant_msg in history:
                messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
                messages.append(
                    {"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]}
                )

            latest_user_content = []
            if user_input:
                latest_user_content.append({"type": "text", "text": user_input})
            latest_user_content.append({"type": "image", "image": user_image})
            messages.append({"role": "user", "content": latest_user_content})

            output = pipe(text=messages, **generation_args)
            ai_response = output[0]["generated_text"][-1]["content"]
        else:
            # Text-only path: build the Gemma chat-template prompt by hand.
            # NOTE(review): the turn delimiters follow the Gemma convention
            # (<start_of_turn>role ... <end_of_turn>); splitting on the final
            # "<start_of_turn>model" isolates the newly generated reply.
            print("No image detected. Using robust 'text-only' format...")
            prompt_parts = [f"<start_of_turn>system\n{system_prompt}<end_of_turn>\n"]
            for user_msg, assistant_msg in history:
                prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>\n")
                prompt_parts.append(f"<start_of_turn>model\n{assistant_msg}<end_of_turn>\n")
            prompt_parts.append(
                f"<start_of_turn>user\n{user_input}<end_of_turn>\n<start_of_turn>model\n"
            )
            prompt = "".join(prompt_parts)

            output = pipe(prompt, **generation_args)
            full_text = output[0]["generated_text"]
            ai_response = full_text.split("<start_of_turn>model")[-1].strip()

        history.append((user_input, ai_response))
        return history, None

    except Exception as e:
        # Best-effort error surface: show the failure in the chat rather than
        # leaving the UI hanging; full details go to the server log.
        history.append((user_input, f"An error occurred: {str(e)}"))
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        return history, None


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation") as demo:
    conversation_history = gr.State([])

    gr.HTML(
        """
        <div style="text-align: center;">
            <h1>🩺 AI Symptom Consultation</h1>
            <p>A conversational AI to help you understand your symptoms, powered by Google's MedGemma</p>
        </div>
        """
    )
    gr.HTML(
        """
        <div style="text-align: center;">
            <b>⚠️ Medical Disclaimer:</b> This is not a diagnosis. This AI is for informational
            purposes and is not a substitute for professional medical advice.
        </div>
        """
    )

    chatbot_display = gr.Chatbot(height=500, label="Consultation")

    with gr.Row():
        image_input = gr.Image(label="Upload Symptom Image (Optional)", type="pil", height=150)
        with gr.Column(scale=4):
            user_textbox = gr.Textbox(
                label="Your Message",
                placeholder="Describe your primary symptom to begin...",
                lines=4,
            )
            send_button = gr.Button("Send Message", variant="primary")

    def submit_message(user_input, user_image, history):
        """Run one turn, then clear the image input and sync the state."""
        updated_history, cleared_image = handle_conversation_turn(user_input, user_image, history)
        # Return the history twice: once for the chatbot display and once to
        # write the gr.State back explicitly, instead of relying on in-place
        # mutation of the state's list.
        return updated_history, updated_history, cleared_image

    # The submit action: update chat, persist state, clear the image slot...
    send_button.click(
        fn=submit_message,
        inputs=[user_textbox, image_input, conversation_history],
        outputs=[chatbot_display, conversation_history, image_input],
    ).then(
        # ...then clear the user's text box after the message is sent.
        lambda: "",
        outputs=user_textbox,
    )

    # Add a clear button for convenience.
    clear_button = gr.Button("🗑️ Start New Consultation")
    clear_button.click(
        lambda: ([], [], None, ""),
        outputs=[chatbot_display, conversation_history, image_input, user_textbox],
    )

if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)