import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import os
import spaces  # Provides the @spaces.GPU decorator used on ZeroGPU Spaces

# Get the Hugging Face token from the environment variables
# Make sure to set this as a "Secret" in your Hugging Face Space settings
hf_token = os.environ.get("HF_TOKEN")
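# Note: google/medgemma-4b-it is a gated model, so the account that issued this
# token must have accepted the model's terms of use on Hugging Face.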

# Initialize the processor and model
# We are using MedGemma, a 4B parameter model specialized for medical text and images.
model_id = "google/medgemma-4b-it"

# Check for GPU availability and set the data type accordingly
# Using bfloat16 for better performance on compatible GPUs.
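# (Compute capability 8.0+ means Ampere or newer GPUs, e.g. A100/A10G, which
# support bfloat16 natively.)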
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    # Fallback to float16 if bfloat16 is not available
    dtype = torch.float16

model_loaded = False
# Load the processor and model from Hugging Face
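# Assumption: AutoModelForCausalLM resolves to MedGemma's multimodal class on this
# transformers version. The model card uses AutoModelForImageTextToText, which may
# be required for image inputs on newer releases.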
try:
    # AutoProcessor handles both text tokenization and image processing
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    # We will display an error in the UI if the model fails to load.

# This is the core function for the chatbot
@spaces.GPU  # Allocate a GPU for this call when the Space runs on ZeroGPU hardware
def symptom_checker_chat(user_input, history, image_input):
    """
    Manages the conversational flow for the symptom checker.
    """
    if not model_loaded:
        history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
        # Return values must line up with the outputs list: chatbot, state, image box, text box
        return history, history, None, ""

    # System prompt to guide the model's behavior
    system_prompt = """
    You are an expert, empathetic AI medical assistant. Your role is to analyze a user's symptoms and provide a helpful, safe, and informative response.
    Here is your workflow:
    1.  Analyze the user's initial input, which may include text and an image.
    2.  If the information is insufficient, ask specific, relevant clarifying questions to better understand the symptoms (e.g., "How long have you had this symptom?", "Can you describe the pain? Is it sharp or dull?").
    3.  Once you have gathered enough information, provide a list of possible conditions that might align with the symptoms.
    4.  For each possible condition, briefly explain why it might be relevant.
    5.  Provide a clear, actionable plan, such as "It would be best to monitor your symptoms," or "You should consider consulting a healthcare professional."
    6.  **Crucially, you must ALWAYS end every single response with the following disclaimer, formatted exactly like this, on a new line:**
    ***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
    """
    
    # Construct the conversation history for the model
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in history:
        conversation.append({"role": "user", "content": user})
        if assistant: # Ensure assistant message is not None
            conversation.append({"role": "assistant", "content": assistant})
            
    # Add the current user input with a special image token if an image is present
    if image_input:
        # Prepend an image placeholder so the processor can pair the uploaded image with the user's text
        conversation.append({"role": "user", "content": f"<image>\n{user_input}"})
    else:
        conversation.append({"role": "user", "content": user_input})
        
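    # Note: the official MedGemma model card instead builds messages as structured
    # content parts (e.g. {"type": "image", "image": ...}) and calls
    # processor.apply_chat_template directly; the manual "<image>" marker above
    # assumes the chat template tolerates a plain placeholder in the text.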
    # Apply the chat template
    prompt = processor.tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )

    inputs = None
    # Process inputs, including the image if it exists
    if image_input:
        inputs = processor(text=prompt, images=image_input, return_tensors="pt").to(model.device, dtype)
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)

    # Generate the output from the model
    try:
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        
        # Decode only the newly generated tokens rather than the full sequence,
        # so the reply does not echo the prompt and conversation back
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()

    except Exception as e:
        print(f"Error during model generation: {e}")
        clean_response = "An error occurred while generating the response. Please check the logs."

    # Update the history
    history.append((user_input, clean_response))
    
    # Return the updated chat twice (for the Chatbot display and the state), clear the image box, and clear the text box
    return history, history, None, ""

# Create the Gradio Interface using Blocks for more control
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash). The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
        """
    )
    
    # Chatbot component to display the conversation
    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))  # Avatar images must exist in the Space repo
    
    # State to store the conversation history
    chat_history = gr.State([])  # Holds the (user, assistant) message tuples between turns

    with gr.Row():
        # Image input
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    with gr.Row():
        # Text input
        text_box = gr.Textbox(
            label="Describe your symptoms...",
            placeholder="e.g., I have a rash on my arm. It's red and itchy.",
            scale=4, # Make the textbox wider
        )
        # Submit button
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Function to clear all inputs
    def clear_all():
        return [], [], None, ""  # Clear the chatbot, the state, the image box, and the text box

    # Clear button
    clear_btn = gr.Button("Start New Conversation")
    # Reset the chatbot display, the stored history, and both input widgets
    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)
        
    # Define what happens when the user submits
    submit_btn.click(
        fn=symptom_checker_chat,
        # Pass the text, stored history, and optional image; receive the updated chat, state, and cleared inputs
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )
        
    # Define what happens when the user just presses Enter in the textbox
    text_box.submit(
        fn=symptom_checker_chat,
        # Same handler and wiring as the Send button
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True) # Debug mode for more detailed logs
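
# A minimal sketch of the Space's requirements.txt (package list is an assumption,
# not taken from the original repo):
#   gradio
#   spaces
#   torch
#   transformers
#   accelerate  # required for device_map="auto"
#   Pillow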