Spaces:
Sleeping
Sleeping
File size: 6,340 Bytes
83ff66a c5882f3 b67fca4 77793f4 83ff66a b0565c1 83ff66a 50aaa9b c5882f3 83ff66a c5882f3 998c789 6ef5bdf 77793f4 83ff66a c5882f3 83ff66a 998c789 b67fca4 83ff66a 909352f 998c789 909352f b67fca4 f0fe4ce b67fca4 909352f 77793f4 c5882f3 909352f f0fe4ce 909352f 60665db d305e52 f0fe4ce 909352f f0fe4ce 909352f f0fe4ce 909352f f0fe4ce c5882f3 f0fe4ce 909352f 77793f4 909352f f0fe4ce 909352f 998c789 9f24600 909352f 998c789 909352f 998c789 909352f f0fe4ce 909352f f0fe4ce 909352f f0fe4ce 909352f f0fe4ce 909352f c5882f3 9c4076b 998c789 ed3187b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 |
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
# --- Initialize the Model Pipeline ---
# Build the MedGemma pipeline once, at import time. A failure here does not stop
# the app from starting: `model_loaded` is set to False (and `pipe` is left
# unbound), so callers must check `model_loaded` before touching `pipe`.
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # None when the HF_TOKEN environment variable is not set.
        token=os.environ.get("HF_TOKEN"),
    )
except Exception as exc:
    model_loaded = False
    print(f"Error loading model: {exc}")
else:
    model_loaded = True
    print("Model loaded successfully!")
# --- Core CONVERSATIONAL Logic ---

# Instructions that steer the model toward a question-first consultation.
SYSTEM_PROMPT = (
    "You are an expert, empathetic AI medical assistant conducting a virtual consultation. "
    "Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. "
    "Do NOT provide a diagnosis or a list of possibilities right away. "
    "Your first step is ALWAYS to ask relevant follow-up questions. Ask only one or two focused questions per turn. "
    "If the user provides an image, acknowledge it by describing what you see first, then ask your questions. "
    "After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. "
    "THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
)

# Gemma control tokens that open and close a chat turn in a raw prompt string.
_TURN_START = "<start_of_turn>"
_TURN_END = "<end_of_turn>"

# Decoding settings forwarded to model.generate() on every turn.
_GENERATE_KWARGS = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}


def _generation_kwargs():
    """Return a fresh copy of the decoding settings for one pipeline call."""
    # New dict per call, in case the pipeline writes into generate_kwargs.
    return dict(_GENERATE_KWARGS)


def _build_multimodal_messages(user_input, user_image, history):
    """Return the chat-format `messages` list for a turn that includes an image.

    Earlier turns are replayed as text-only user/assistant messages; the new
    user message carries the optional text followed by the image.
    """
    messages = [{"role": "system", "content": [{"type": "text", "text": SYSTEM_PROMPT}]}]
    for user_msg, assistant_msg in history:
        messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})
    latest_user_content = []
    if user_input:
        latest_user_content.append({"type": "text", "text": user_input})
    latest_user_content.append({"type": "image", "image": user_image})
    messages.append({"role": "user", "content": latest_user_content})
    return messages


def _build_text_prompt(user_input, history):
    """Return a raw Gemma-format prompt string for a text-only turn.

    Every turn is rendered as the turn-start token, the role ("user" or
    "model"), a newline, the text, the turn-end token and a newline. Gemma
    has no separate system role, so the system prompt is prepended to the
    first user turn. The prompt ends with an open "model" turn for the reply.
    """
    turns = []
    for user_msg, assistant_msg in history:
        turns.append(("user", user_msg))
        turns.append(("model", assistant_msg))
    turns.append(("user", user_input))
    # Fold the system instructions into the first user turn.
    first_role, first_text = turns[0]
    turns[0] = (first_role, f"{SYSTEM_PROMPT}\n\n{first_text}")
    prompt_parts = [f"{_TURN_START}{role}\n{text}{_TURN_END}\n" for role, text in turns]
    # Leave a model turn open so generation produces the assistant's reply.
    prompt_parts.append(f"{_TURN_START}model\n")
    return "".join(prompt_parts)


@spaces.GPU()
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """Run one consultation turn and append it to the conversation history.

    Turns with an image are sent to the pipeline as chat-format messages;
    text-only turns are sent as a raw Gemma prompt string.

    Args:
        user_input: The user's message text; may be empty when only an image is sent.
        user_image: PIL image uploaded for this turn, or None.
        history: List of (user_message, ai_response) tuples. Mutated in place.

    Returns:
        Tuple (history, None): the updated history and None to clear the image
        input. Failures (model not loaded, generation error) are appended to the
        history as the AI response instead of being raised.
    """
    # Nothing submitted: skip the (GPU) generation and leave the chat unchanged.
    if not (user_input or "").strip() and user_image is None:
        return history, None
    if not model_loaded:
        history.append((user_input, "Error: The AI model is not loaded. Please contact the administrator."))
        return history, None
    try:
        if user_image is not None:
            print("Image detected. Using multimodal 'messages' format...")
            messages = _build_multimodal_messages(user_input, user_image, history)
            output = pipe(text=messages, generate_kwargs=_generation_kwargs())
            # Chat input returns the whole conversation; the last message is the new reply.
            ai_response = output[0]["generated_text"][-1]["content"]
        else:
            print("No image detected. Using robust 'text-only' format...")
            prompt = _build_text_prompt(user_input, history)
            output = pipe(prompt, generate_kwargs=_generation_kwargs(), return_full_text=False)
            generated = output[0]["generated_text"]
            # Defensive: if the prompt is echoed anyway, keep only what follows the
            # final open model turn (no-op when the marker is absent).
            reply = generated.rpartition(f"{_TURN_START}model\n")[2]
            ai_response = reply.replace(_TURN_END, "").strip()
        history.append((user_input, ai_response))
        return history, None
    except Exception as e:
        # UI boundary: surface the error in the chat instead of crashing the handler.
        history.append((user_input, f"An error occurred: {str(e)}"))
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        return history, None
# --- Gradio Interface ---
# NOTE(review): the Row/Column nesting below was reconstructed from a
# whitespace-stripped source — confirm the intended layout.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation") as demo:
    # Per-session list of (user_message, ai_response) tuples; also what the Chatbot renders.
    conversation_history = gr.State([])
    gr.HTML("""
<div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
<h1>🩺 AI Symptom Consultation</h1>
<p>A conversational AI to help you understand your symptoms, powered by Google's MedGemma</p>
</div>
""")
    gr.HTML("""
<div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
<strong>⚠️ Medical Disclaimer:</strong> This is not a diagnosis. This AI is for informational purposes and is not a substitute for professional medical advice.
</div>
""")
    chatbot_display = gr.Chatbot(height=500, label="Consultation")
    with gr.Row():
        image_input = gr.Image(label="Upload Symptom Image (Optional)", type="pil", height=150)
        with gr.Column(scale=4):
            user_textbox = gr.Textbox(label="Your Message", placeholder="Describe your primary symptom to begin...", lines=4)
            send_button = gr.Button("Send Message", variant="primary")

    def submit_message(user_input, user_image, history):
        """Run one turn and return (chat display, stored history, cleared image)."""
        updated_history, cleared_image = handle_conversation_turn(user_input, user_image, history)
        # Write the history back to the State explicitly (as the clear button does)
        # instead of relying on the handler mutating the list in place.
        return updated_history, updated_history, cleared_image

    send_button.click(
        fn=submit_message,
        inputs=[user_textbox, image_input, conversation_history],
        outputs=[chatbot_display, conversation_history, image_input],
    ).then(lambda: "", outputs=user_textbox)  # clear the textbox after sending

    clear_button = gr.Button("🗑️ Start New Consultation")
    # Reset chat display, stored history, image and textbox in one event.
    clear_button.click(lambda: ([], [], None, ""), outputs=[chatbot_display, conversation_history, image_input, user_textbox])
def _launch_app():
    """Announce startup and serve the Gradio UI with debug output enabled."""
    print("Starting Gradio interface...")
    demo.launch(debug=True)


if __name__ == "__main__":
    _launch_app()