File size: 7,979 Bytes
83ff66a
 
b67fca4
 
83ff66a
b0565c1
83ff66a
b0565c1
 
83ff66a
 
 
b0565c1
b67fca4
 
83ff66a
 
b0565c1
83ff66a
2166c8b
83ff66a
b67fca4
 
83ff66a
 
 
 
b0565c1
83ff66a
 
b0565c1
83ff66a
b67fca4
b0565c1
83ff66a
b0565c1
 
a2c1346
b0565c1
b67fca4
b0565c1
 
 
 
 
 
 
 
 
 
b67fca4
 
b0565c1
 
b67fca4
 
b0565c1
 
 
 
 
 
 
 
b67fca4
 
b0565c1
 
 
 
a2c1346
 
 
b0565c1
b67fca4
b0565c1
a2c1346
 
 
b0565c1
 
a2c1346
bd084e6
b0565c1
 
a2c1346
bd084e6
 
b67fca4
 
2166c8b
b0565c1
b67fca4
b0565c1
 
 
 
 
 
b67fca4
b0565c1
 
bd084e6
 
 
b67fca4
 
 
 
 
b0565c1
 
b67fca4
b0565c1
a2c1346
b0565c1
a2c1346
 
b67fca4
b0565c1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
83ff66a
b0565c1
83ff66a
b0565c1
 
bd084e6
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import os
import spaces

# --- Configuration ---
# The Hugging Face access token is read from the environment (set it as a Secret in the Space).
hf_token = os.environ.get("HF_TOKEN")
model_id = "google/medgemma-4b-it"

# --- Model Loading ---
# bfloat16 is chosen only when CUDA is available and the device reports a compute
# capability major version of 8 or higher; every other case uses float16.
dtype = (
    torch.bfloat16
    if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8
    else torch.float16
)

# Keyword arguments passed to the model loader.
_model_kwargs = {
    "token": hf_token,
    "torch_dtype": dtype,
    "device_map": "auto",  # Automatically use available GPU(s)
}

# Stays False unless both the processor and the model load without raising.
model_loaded = False
try:
    # The processor is used later for chat templating, input preparation and decoding.
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(model_id, **_model_kwargs)
    model_loaded = True
    print("Model loaded successfully on device:", model.device)
except Exception as e:
    # Loading failures are only logged here; the chat handler checks
    # `model_loaded` and reports the problem in the UI.
    print(f"Error loading model: {e}")

# --- Core Chatbot Function ---
def _without_image_marker(text):
    """Return *text* with the image marker prepended to image turns removed."""
    return text.replace("<image>\n", "")


def _to_display_history(history):
    """Build the chatbot-facing copy of *history* with image markers stripped."""
    return [(_without_image_marker(turn_input), reply) for turn_input, reply in history]


# Add the GPU decorator to signal to Hugging Face Spaces that we need a GPU
@spaces.GPU
def symptom_checker_chat(user_input, history_state, image_input):
    """
    Run one chat turn: build the prompt from the history, generate a reply,
    and return the values used to refresh the UI.

    Args:
        user_input (str): The text typed by the user. None is treated as "".
        history_state (list): Earlier turns as (user_text, assistant_text) tuples.
            The user text of a turn that carried an image starts with the
            "<image>" marker line. The list is not modified; an updated copy
            is returned instead.
        image_input (PIL.Image): The uploaded image, or None when there is none.

    Returns:
        tuple: (display history with image markers stripped, updated state
        history, None to clear the image box, "" to clear the text box).
        When both the text and the image are empty the history is returned
        unchanged and the model is not called.
    """
    history = list(history_state) if history_state else []
    text = user_input or ""
    has_image = image_input is not None

    if not model_loaded:
        history.append((text, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
        return _to_display_history(history), history, None, ""

    # Nothing to send: leave the conversation untouched rather than spending a
    # generation call on an empty message.
    if not has_image and not text.strip():
        return _to_display_history(history), history, None, ""

    system_prompt = """
    You are an expert, empathetic AI medical assistant. Your role is to analyze a user's symptoms and provide a helpful, safe, and informative response.
    Here is your workflow:
    1.  Analyze the user's initial input, which may include text and an image. If an image is provided, your first step is to describe what you see in the image relevant to the user's query.
    2.  If the information is insufficient, ask specific, relevant clarifying questions to better understand the symptoms (e.g., "How long have you had this symptom?", "Can you describe the pain? Is it sharp or dull?").
    3.  Once you have gathered enough information, provide a list of possible conditions that might align with the symptoms.
    4.  For each possible condition, briefly explain why it might be relevant.
    5.  Provide a clear, actionable plan, such as "It would be best to monitor your symptoms," or "You should consider consulting a healthcare professional."
    6.  **Crucially, you must ALWAYS end every single response with the following disclaimer, formatted exactly like this, on a new line:**
    ***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
    """

    # The current turn carries the image marker only when an image was uploaded.
    # This marked text is also what gets saved in the state history.
    # NOTE(review): "<image>" is assumed to be the placeholder this model's
    # processor pairs with `images=` — confirm against the processor's docs.
    model_input_text = f"<image>\n{text}" if has_image else text

    # Rebuild the conversation for the model. Images from earlier turns are not
    # kept, so their markers are stripped: replaying them would leave markers in
    # the prompt with no matching image passed to the processor.
    conversation = [{"role": "system", "content": system_prompt}]
    for past_input, past_reply in history:
        conversation.append({"role": "user", "content": _without_image_marker(past_input)})
        if past_reply:
            conversation.append({"role": "assistant", "content": past_reply})
    conversation.append({"role": "user", "content": model_input_text})

    try:
        # Templating lives inside the try block so that a template error is
        # reported in the chat like any other generation failure.
        prompt = processor.apply_chat_template(
            conversation,
            tokenize=False,
            add_generation_prompt=True,
        )

        # Process the inputs (text and optional image)
        processor_kwargs = {"text": prompt, "return_tensors": "pt"}
        if has_image:
            processor_kwargs["images"] = image_input
        inputs = processor(**processor_kwargs).to(model.device, dtype)

        # Generate the response
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)

        # Decode only the tokens that follow the prompt so the reply does not
        # repeat the input.
        prompt_length = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, prompt_length:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error during model generation: {e}")
        clean_response = "An error occurred while generating the response. Please check the logs."

    # Save the turn with the model-aware input text (marker included).
    history.append((model_input_text, clean_response))

    # Return the UI history, the state history, and values to clear the inputs.
    return _to_display_history(history), history, None, ""

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash). The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
        """
    )

    # Conversation display, plus the per-session state holding the history
    # (including model-specific tokens).
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
    chat_history = gr.State([])

    # Optional image upload, handed to the handler as a PIL image.
    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    # Text entry and the send button share one row.
    with gr.Row():
        text_box = gr.Textbox(
            label="Describe your symptoms...",
            placeholder="e.g., I have a rash on my arm. It's red and itchy.",
            scale=4,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    def clear_all():
        """Return reset values for the chatbot, the state, the image box and the text box."""
        return [], [], None, ""

    # Components read by the chat handler, and the components it writes back to.
    turn_inputs = [text_box, chat_history, image_box]
    turn_outputs = [chatbot, chat_history, image_box, text_box]

    # Resetting the conversation clears every component the handler writes to.
    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(fn=clear_all, outputs=turn_outputs, queue=False)

    # Clicking "Send" and pressing Enter in the textbox run the same handler
    # with the same wiring.
    for register_event in (submit_btn.click, text_box.submit):
        register_event(
            fn=symptom_checker_chat,
            inputs=turn_inputs,
            outputs=turn_outputs,
        )

# --- Launch the App ---
if __name__ == "__main__":
    # Debug mode provides more detailed logs in the console
    demo.launch(debug=True)