# Hugging Face Space — status: Running
import gradio as gr | |
from transformers import pipeline | |
from PIL import Image | |
import torch | |
import os | |
import spaces | |
# --- Model initialisation ---
# The MedGemma pipeline is created once at import time and shared by every
# request. If loading fails the app still starts: `model_loaded` is False and
# the chat handler answers with an error message instead of calling the model.
print("Loading MedGemma model...")
pipe = None  # bound up front so the name exists even when loading fails
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # Hugging Face access token from the environment (None when unset).
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    # Deliberately broad: any loading failure degrades to the "not loaded" state.
    model_loaded = False
    print(f"Error loading model: {e}")
# --- Core conversational logic (simulated streaming) ---

# Instructions given to the model on every turn.
_SYSTEM_PROMPT = (
    "You are an expert, empathetic AI medical assistant conducting a virtual consultation. "
    "Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. "
    "Do NOT provide a diagnosis or a list of possibilities right away. Ask only one or two focused questions per turn. "
    "If the user provides an image, your first step is to analyze it from an expert perspective. Briefly describe the key findings from the image. "
    "Then, use this analysis to ask relevant follow-up questions about the user's symptoms or medical history to better understand the context. "
    # Trailing space matters: implicit concatenation would otherwise glue the next sentence on.
    "For example, after seeing a rash, you might say, 'I see a reddish rash with well-defined borders on the forearm. To help me understand more, could you tell me when you first noticed this? Is it itchy, painful, or does it have any other sensation?' "
    "After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. "
    "THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
)

# Sampling settings shared by both generation paths.
_GENERATION_ARGS = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}


# Runs one multimodal turn through the pipeline's chat-message input and returns the reply text.
def _generate_with_image(user_input, user_image, history):
    messages = [{"role": "system", "content": [{"type": "text", "text": _SYSTEM_PROMPT}]}]
    for user_msg, assistant_msg in history[:-1]:
        messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
    messages.append({
        "role": "user",
        "content": [{"type": "text", "text": user_input}, {"type": "image", "image": user_image}],
    })
    output = pipe(text=messages, **_GENERATION_ARGS)
    # With chat input the pipeline returns the conversation; its last entry is the new reply.
    return output[0]["generated_text"][-1]["content"]


# Builds a raw Gemma prompt (user/model turns closed by <end_of_turn>) ending in an open model turn.
def _build_gemma_prompt(user_input, history):
    turns = []
    for user_msg, assistant_msg in history[:-1]:
        turns.append(("user", user_msg))
        if assistant_msg:
            turns.append(("model", assistant_msg))
    turns.append(("user", user_input))
    # Gemma's turn format has only user/model roles, so the instructions are
    # prepended to the first user turn rather than sent as a "system" turn.
    first_role, first_text = turns[0]
    turns[0] = (first_role, f"{_SYSTEM_PROMPT}\n\n{first_text}")
    body = "".join(f"<start_of_turn>{role}\n{text}<end_of_turn>\n" for role, text in turns)
    return body + "<start_of_turn>model\n"


# Runs one text-only turn from a hand-built Gemma prompt and returns the reply text.
def _generate_text_only(user_input, history):
    prompt = _build_gemma_prompt(user_input, history)
    output = pipe(prompt, **_GENERATION_ARGS)
    full_text = output[0]["generated_text"]
    # The returned text echoes the prompt; keep only what follows the last model-turn marker.
    return full_text.rpartition("<start_of_turn>model")[2].strip()


# NOTE(review): `spaces` is imported at the top of the file but no `@spaces.GPU`
# decorator is applied here; if this Space runs on ZeroGPU hardware the
# generation call likely needs it — confirm.
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """Generate the assistant's reply for the newest turn and stream it into ``history``.

    ``history`` is the chat list of ``(user_text, assistant_text)`` tuples. Its last
    entry is the turn being answered and is overwritten in place; earlier entries are
    replayed to the model as context. If ``history`` is empty, an entry is added.

    This is a generator. Each yield is ``(history, history, None)``: the chatbot value,
    the updated state, and ``None`` to clear the image input. The reply is produced by
    one blocking model call and then revealed one character per yield, so the
    streaming is simulated, not token-by-token.

    Errors (including the model not being loaded) are reported as the assistant's
    message in ``history`` rather than raised.
    """
    if not history:
        # Defensive: the UI always appends the user's turn first, but direct callers may not.
        history.append((user_input, None))
    if not model_loaded:
        history[-1] = (user_input, "Error: The AI model is not loaded.")
        yield history, history, None
        return
    try:
        if user_image is not None:
            ai_response = _generate_with_image(user_input, user_image, history)
        else:
            ai_response = _generate_text_only(user_input, history)
        history[-1] = (user_input, "")
        if not ai_response:
            # Still emit one update so the UI replaces the pending (None) reply slot.
            yield history, history, None
        # Reveal progressively longer prefixes of the finished reply.
        for end in range(1, len(ai_response) + 1):
            history[-1] = (user_input, ai_response[:end])
            yield history, history, None
    except Exception as e:
        error_message = f"An error occurred: {str(e)}"
        history[-1] = (user_input, error_message)
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        yield history, history, None
# --- UI styling ---
# Custom CSS handed to gr.Blocks(css=...). The #ids below must match the
# elem_id values given to the components: chat-container, footer-container,
# user-textbox and image-upload-btn.
# The footer is position: fixed at the bottom of the window; the 120px
# bottom padding on the chat area presumably keeps the last messages from
# being hidden behind it (confirm if the footer height changes).
css = """
/* Make the main app container fill the screen height */
div.gradio-container { height: 100vh !important; }
/* Main chat area styling */
#chat-container { flex-grow: 1; overflow-y: auto; padding-bottom: 120px; }
/* Sticky footer for the input bar */
#footer-container {
position: fixed !important; bottom: 0; left: 0; width: 100%;
background-color: #e0f2fe !important; /* Light Sky Blue background */
border-top: 1px solid #bae6fd !important;
padding: 10px; z-index: 1000;
}
/* White, rounded textbox */
#user-textbox textarea {
background-color: #ffffff !important;
border: 1px solid #cbd5e1 !important;
border-radius: 8px !important;
}
/* Style the image upload button */
#image-upload-btn { border: 1px dashed #9ca3af !important; border-radius: 8px !important; }
"""
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation", css=css) as demo:
    # Per-session chat transcript: a list of (user_text, assistant_text) tuples.
    conversation_history = gr.State([])

    with gr.Column(elem_id="chat-container"):
        chatbot_display = gr.Chatbot(label="Consultation", show_copy_button=True, bubble_full_width=False)

    # Input bar; the #footer-container CSS rule pins it to the bottom of the window.
    with gr.Column(elem_id="footer-container"):
        with gr.Row():
            image_input = gr.Image(
                elem_id="image-upload-btn",
                label="Image",
                type="pil",
                height=80,
                show_label=False,
                container=False,
                scale=1,
            )
            user_textbox = gr.Textbox(
                elem_id="user-textbox",
                label="Your Message",
                placeholder="Type your message, or upload an image...",
                show_label=False,
                scale=4,
                container=False,
            )
            send_button = gr.Button("Send", variant="primary", scale=1)
        with gr.Row():
            # NOTE(review): the label previously held mojibake ("๐๏ธ"), which matches the
            # UTF-8 bytes of "🗑️" decoded as a Thai code page; confirm the intended emoji.
            clear_button = gr.Button("🗑️ Start New Consultation")

    def submit_message_and_stream(user_input: str, user_image: Image.Image, history: list):
        """Show the user's message immediately, then stream the AI reply.

        Generator used as a Gradio event handler. Every yield is
        ``(chatbot_value, new_state, image_value)``. An empty submission (no text
        and no image) re-emits the current history once and stops.
        """
        if not (user_input or "").strip() and user_image is None:
            # A `return value` inside a generator is discarded, so yield explicitly.
            yield history, history, None
            return
        # 1. Add the user's turn right away; its reply slot is filled while streaming.
        history.append((user_input, None))
        yield history, history, None
        # 2. Relay every update from the AI response stream.
        yield from handle_conversation_turn(user_input, user_image, history)

    # --- Event Handlers ---
    # "Send" and pressing Enter share one handler. `.then` clears the textbox
    # only after the reply has finished streaming.
    for trigger in (send_button.click, user_textbox.submit):
        trigger(
            fn=submit_message_and_stream,
            inputs=[user_textbox, image_input, conversation_history],
            outputs=[chatbot_display, conversation_history, image_input],
        ).then(lambda: "", outputs=user_textbox)

    # Reset the chat display, stored history, image and textbox.
    clear_button.click(
        lambda: ([], [], None, ""),
        outputs=[chatbot_display, conversation_history, image_input, user_textbox],
    )
def main():
    """Announce startup, then serve the consultation UI with Gradio's debug mode enabled."""
    print("Starting Gradio interface...")
    demo.launch(debug=True)


if __name__ == "__main__":
    main()