import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces

# --- Initialize the Model Pipeline (No changes) ---
# Loaded once at import time; `model_loaded` gates every conversation turn so
# the UI can surface a friendly error instead of crashing when the model
# (or the HF_TOKEN gating google/medgemma-4b-it) is unavailable.
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,  # half-precision to fit the 4B model on GPU
        device_map="auto",
        token=os.environ.get("HF_TOKEN")
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")


# --- Core CONVERSATIONAL Logic (No changes) ---
@spaces.GPU()  # request a ZeroGPU slot for the duration of this call
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """
    Manages a single conversation turn and streams the AI response back.

    Generator: yields `(history, history)` pairs — one value for the Chatbot
    display, one for the gr.State copy.  The caller has already appended
    `(user_input, None)` as the last history entry; this function fills in
    the assistant half of that entry, one character at a time, to simulate
    typing.  `history` is a list of `(user_msg, assistant_msg)` tuples.
    """
    if not model_loaded:
        history[-1] = (user_input, "Error: The AI model is not loaded.")
        yield history, history
        return
    try:
        system_prompt = (
            "You are an expert, empathetic AI medical assistant conducting a virtual consultation. Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. Do NOT provide a diagnosis or a list of possibilities right away. Ask only one or two focused questions per turn. If the user provides an image, your first step is to analyze it from an expert perspective. Briefly describe the key findings from the image. Then, use this analysis to ask relevant follow-up questions about the user's symptoms or medical history to better understand the context. For example, after seeing a rash, you might say, 'I see a reddish rash with well-defined borders on the forearm. To help me understand more, could you tell me when you first noticed this? Is it itchy, painful, or does it have any other sensation?' After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
        )
        generation_args = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}
        ai_response = ""
        if user_image:
            # Image path: use the structured chat-message format so the
            # pipeline can attach the PIL image to the latest user turn.
            messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
            # Replay all completed turns; history[:-1] excludes the in-progress
            # entry appended by the caller (its assistant half is still None).
            for user_msg, assistant_msg in history[:-1]:
                messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
                if assistant_msg:
                    messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
            latest_user_content = [{"type": "text", "text": user_input}, {"type": "image", "image": user_image}]
            messages.append({"role": "user", "content": latest_user_content})
            output = pipe(text=messages, **generation_args)
            # NOTE(review): assumes the pipeline returns the full message list
            # in "generated_text" with the new assistant message last — confirm
            # against the installed transformers version.
            ai_response = output[0]["generated_text"][-1]["content"]
        else:
            # Text-only path: build a flat role-tagged prompt string instead
            # of chat messages.
            prompt_parts = [f"system\n{system_prompt}"]
            for user_msg, assistant_msg in history[:-1]:
                prompt_parts.append(f"user\n{user_msg}")
                if assistant_msg:
                    prompt_parts.append(f"model\n{assistant_msg}")
            prompt_parts.append(f"user\n{user_input}")
            prompt_parts.append("model")
            prompt = "\n".join(prompt_parts)
            output = pipe(prompt, **generation_args)
            full_text = output[0]["generated_text"]
            # HACK: keep only the text after the last "model" marker.  This is
            # fragile — it also splits on the word "model" if it appears inside
            # the generated reply.
            ai_response = full_text.split("model")[-1].strip()
        # Stream the reply character-by-character for a typing effect.
        history[-1] = (user_input, "")
        for character in ai_response:
            history[-1] = (user_input, history[-1][1] + character)
            yield history, history
    except Exception as e:
        # Surface any generation failure in the chat itself rather than dying.
        error_message = f"An error occurred: {str(e)}"
        history[-1] = (user_input, error_message)
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        yield history, history


# --- UI MODIFICATION: Final CSS for a clean, white background ---
css = """
/* Set a base background for the overall app if needed */
.gradio-container { background-color: #f9fafb !important; }
/* THIS IS THE KEY FIX: Target the chat container directly for a white background */
#chat-container {
    background-color: #ffffff !important;
    flex-grow: 1;
    overflow-y: auto;
    padding: 1rem;
    border: 1px solid #e5e7eb;
    border-radius: 12px;
    margin-bottom: 120px !important; /* Space for the sticky footer */
}
/* Chat Bubble Styling */
.user > .message-bubble-row .message-bubble {
    background-color: #dbeafe !important;
    color: #1f2937 !important;
    border-top-right-radius: 5px !important;
}
.bot > .message-bubble-row .message-bubble {
    background-color: #f3f4f6 !important;
    color: #1f2937 !important;
    border-top-left-radius: 5px !important;
}
/* Sticky Footer Input Bar */
#footer-container {
    position: fixed !important;
    bottom: 0;
    left: 0;
    width: 100%;
    background-color: #ffffff !important;
    border-top: 1px solid #e5e7eb !important;
    padding: 1rem !important;
    z-index: 1000;
}
/* Text Input Box Styling */
#user-textbox textarea {
    background-color: #f9fafb !important;
    border: 1px solid #d1d5db !important;
    border-radius: 10px !important;
    color: #111827 !important;
}
/* Icon Button General Styling */
.icon-btn {
    min-width: 50px !important;
    max-width: 50px !important;
    height: 50px !important;
    font-size: 1.5rem !important;
    border: none !important;
    border-radius: 10px !important;
}
/* Specific Icon Button Colors */
#restart-btn { background-color: #fee2e2 !important; color: #ef4444 !important; }
#upload-btn, #send-btn { background-color: #3b82f6 !important; color: white !important; }
"""

# Reverting to gr.themes.Soft() for a reliable light theme foundation
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation", css=css) as demo:
    # Server-side copy of the chat history, mirrored into the Chatbot display.
    conversation_history = gr.State([])
    # The main content area now has the specific ID for styling.
with gr.Column(elem_id="chat-container"): chatbot_display = gr.Chatbot(elem_id="chatbot", label="Consultation", show_copy_button=True, bubble_full_width=False, avatar_images=("./images/user.png", "./images/bot.png")) # The sticky footer for inputs with gr.Column(elem_id="footer-container"): with gr.Row(): image_preview = gr.Image(type="pil", height=60, width=60, visible=False, show_label=False, container=False, scale=1) upload_button = gr.UploadButton("📷", file_types=["image"], elem_id="upload-btn", elem_classes="icon-btn", scale=1) user_textbox = gr.Textbox( elem_id="user-textbox", placeholder="Type your message, or upload an image...", show_label=False, scale=5, container=False ) send_button = gr.Button("➤", elem_id="send-btn", elem_classes="icon-btn", scale=1) clear_button = gr.Button("🔄", elem_id="restart-btn", elem_classes="icon-btn", scale=1) def show_image_preview(image_file): return gr.Image(value=image_file, visible=True) upload_button.upload(fn=show_image_preview, inputs=upload_button, outputs=image_preview) def submit_message_and_stream(user_input: str, user_image: Image.Image, history: list): if not user_input.strip() and user_image is None: return history, history, None history.append((user_input, None)) yield history, history, None for updated_history, new_state in handle_conversation_turn(user_input, user_image, history): yield updated_history, new_state def clear_inputs(): return "", None send_button.click( fn=submit_message_and_stream, inputs=[user_textbox, image_preview, conversation_history], outputs=[chatbot_display, conversation_history], ).then(fn=clear_inputs, outputs=[user_textbox, image_preview]) user_textbox.submit( fn=submit_message_and_stream, inputs=[user_textbox, image_preview, conversation_history], outputs=[chatbot_display, conversation_history], ).then(fn=clear_inputs, outputs=[user_textbox, image_preview]) clear_button.click( lambda: ([], [], None, ""), outputs=[chatbot_display, conversation_history, image_preview, user_textbox] ) 
# Script entry point: start the Gradio server.  debug=True surfaces full
# tracebacks in the browser and keeps the process attached for logs.
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)