import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
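# The `spaces` package supplies the @spaces.GPU decorator used below; on
# Hugging Face ZeroGPU Spaces it allocates a GPU for the duration of the
# decorated function's call.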

# --- Initialize the Model Pipeline ---
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=os.environ.get("HF_TOKEN")
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")

# --- Core CONVERSATIONAL Logic ---
@spaces.GPU()
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """
    Manages a single turn of the conversation, using separate logic for image and text-only inputs.
    """
    if not model_loaded:
        history.append((user_input, "Error: The AI model is not loaded. Please contact the administrator."))
        return history, None

    try:
        system_prompt = (
            "You are an expert, empathetic AI medical assistant conducting a virtual consultation. "
            "Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. "
            "Do NOT provide a diagnosis or a list of possibilities right away. "
            "Your first step is ALWAYS to ask relevant follow-up questions. Ask only one or two focused questions per turn. "
            "If the user provides an image, acknowledge it by describing what you see first, then ask your questions. "
            "After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. "
            "THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
        )

        generation_args = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}
        ai_response = ""
        
        # Two logic paths: a multimodal "messages" request when an image is
        # attached, and a plain Gemma chat-format prompt string otherwise.

        if user_image is not None:
            # --- PATH 1: Image is present. Use the proven 'messages' format. ---
            print("Image detected. Using multimodal 'messages' format...")
            
            messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]
            for user_msg, assistant_msg in history:
                messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
                messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})

            latest_user_content = []
            if user_input:
                latest_user_content.append({"type": "text", "text": user_input})
            latest_user_content.append({"type": "image", "image": user_image})
            messages.append({"role": "user", "content": latest_user_content})
            
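            # With chat-style "messages" input, the pipeline returns the whole
            # message list back, so the new reply is the last message's content.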
            output = pipe(text=messages, **generation_args)
            ai_response = output[0]["generated_text"][-1]["content"]

        else:
            # --- PATH 2: No image. Use a simple, robust prompt string. ---
            print("No image detected. Using robust 'text-only' format...")

            # Manually build a single string in the Gemma chat format: each
            # turn is "<start_of_turn>{role}\n{text}<end_of_turn>\n", and the
            # prompt ends with an open model turn for the model to complete.
            # Gemma has no dedicated system role, so the system prompt is
            # sent as the opening user turn.
            prompt_parts = [f"<start_of_turn>user\n{system_prompt}<end_of_turn>\n"]
            for user_msg, assistant_msg in history:
                prompt_parts.append(f"<start_of_turn>user\n{user_msg}<end_of_turn>\n")
                prompt_parts.append(f"<start_of_turn>model\n{assistant_msg}<end_of_turn>\n")

            # Add the latest user message and open the model's turn
            prompt_parts.append(f"<start_of_turn>user\n{user_input}<end_of_turn>\n<start_of_turn>model\n")
            prompt = "".join(prompt_parts)

            output = pipe(prompt, **generation_args)
            full_text = output[0]["generated_text"]
            # The pipeline output includes the prompt, so keep only the text
            # after the final model turn and drop the end-of-turn tag.
            ai_response = full_text.split("<start_of_turn>model")[-1].replace("<end_of_turn>", "").strip()

        # Update the history and clear the image box
        history.append((user_input, ai_response))
        return history, None

    except Exception as e:
        history.append((user_input, f"An error occurred: {str(e)}"))
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        return history, None

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation") as demo:
    conversation_history = gr.State([])

    gr.HTML("""
        <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
            <h1>🩺 AI Symptom Consultation</h1>
            <p>A conversational AI to help you understand your symptoms, powered by Google's MedGemma</p>
        </div>
    """)
    gr.HTML("""
        <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
            <strong>⚠️ Medical Disclaimer:</strong> This is not a diagnosis. This AI is for informational purposes and is not a substitute for professional medical advice.
        </div>
    """)

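    # The chatbot renders the conversation as (user, assistant) tuples, the
    # same structure stored in conversation_history.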
    chatbot_display = gr.Chatbot(height=500, label="Consultation")
    
    with gr.Row():
        image_input = gr.Image(label="Upload Symptom Image (Optional)", type="pil", height=150)
        with gr.Column(scale=4):
            user_textbox = gr.Textbox(label="Your Message", placeholder="Describe your primary symptom to begin...", lines=4)
            send_button = gr.Button("Send Message", variant="primary")

    def submit_message(user_input, user_image, history):
        updated_history, cleared_image = handle_conversation_turn(user_input, user_image, history)
        # Return the history twice: once for the visible chatbot and once to
        # keep the gr.State copy in sync, rather than relying on in-place
        # mutation of the state list.
        return updated_history, updated_history, cleared_image

    send_button.click(
        fn=submit_message,
        inputs=[user_textbox, image_input, conversation_history],
        outputs=[chatbot_display, conversation_history, image_input]
    ).then(lambda: "", outputs=user_textbox)
    
    clear_button = gr.Button("🗑️ Start New Consultation")
    clear_button.click(lambda: ([], [], None, ""), outputs=[chatbot_display, conversation_history, image_input, user_textbox])

if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)