Update app.py
- GPU Request: Added import spaces and the @spaces.GPU decorator to symptom_checker_chat.
- State Management:
  - The click and submit events now use chat_history as both an input and an output.
  - symptom_checker_chat accepts the history from the state and returns the updated list to both the chatbot and chat_history (a minimal wiring sketch follows the diff below).
- Robust Parsing: Replaced the fragile rfind() logic with a more reliable method that decodes only the newly generated tokens (see the sketch after this list).
- UI Cleanup:
  - Added text_box to the outputs of the event handlers.
  - The function now returns "" as its last value to clear the textbox after each submission.
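For reference, the parsing fix is the standard slice-by-prompt-length pattern: generate() returns prompt plus completion, so everything past the prompt's token count is the model's reply. A minimal, self-contained sketch of the idea, using gpt2 as a stand-in checkpoint (the Space itself loads google/medgemma-4b-it through its processor):

```python
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM

# Stand-in model for illustration only; the Space uses google/medgemma-4b-it.
model_id = "gpt2"
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)

inputs = tokenizer("Patient: I have a persistent cough.", return_tensors="pt")
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=32)

# The first input_token_len positions of the output are the prompt itself;
# slicing past them isolates only the newly generated tokens.
input_token_len = inputs["input_ids"].shape[1]
generated_tokens = outputs[:, input_token_len:]
clean_response = tokenizer.decode(generated_tokens[0], skip_special_tokens=True).strip()
print(clean_response)
```

This avoids string searches like rfind("assistant"), which break whenever the marker text happens to appear in the conversation itself.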
app.py CHANGED

@@ -3,8 +3,10 @@ import torch
 from transformers import AutoProcessor, AutoModelForCausalLM
 from PIL import Image
 import os
+import spaces  # <-- FIX 1: IMPORT SPACES
 
 # Get the Hugging Face token from the environment variables
+# Make sure to set this as a "Secret" in your Hugging Face Space settings
 hf_token = os.environ.get("HF_TOKEN")
 
 # Initialize the processor and model
@@ -16,7 +18,8 @@ model_id = "google/medgemma-4b-it"
 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
     dtype = torch.bfloat16
 else:
-    dtype = torch.float16
+    # Fallback to float16 if bfloat16 is not available
+    dtype = torch.float16
 
 model_loaded = False
 # Load the processor and model from Hugging Face
@@ -34,20 +37,20 @@ except Exception as e:
     print(f"Error loading model: {e}")
     # We will display an error in the UI if the model fails to load.
 
-
 # This is the core function for the chatbot
+@spaces.GPU  # <-- FIX 1: ADD THE GPU DECORATOR
 def symptom_checker_chat(user_input, history, image_input):
     """
     Manages the conversational flow for the symptom checker.
     """
     if not model_loaded:
         history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
-        return history, None
+        # <-- FIX 3 & 4: Return values match new outputs
+        return history, history, None, ""
 
     # System prompt to guide the model's behavior
     system_prompt = """
     You are an expert, empathetic AI medical assistant. Your role is to analyze a user's symptoms and provide a helpful, safe, and informative response.
-
     Here is your workflow:
     1. Analyze the user's initial input, which may include text and an image.
     2. If the information is insufficient, ask specific, relevant clarifying questions to better understand the symptoms (e.g., "How long have you had this symptom?", "Can you describe the pain? Is it sharp or dull?").
@@ -55,7 +58,6 @@ def symptom_checker_chat(user_input, history, image_input):
     4. For each possible condition, briefly explain why it might be relevant.
     5. Provide a clear, actionable plan, such as "It would be best to monitor your symptoms," or "You should consider consulting a healthcare professional."
     6. **Crucially, you must ALWAYS end every single response with the following disclaimer, formatted exactly like this, on a new line:**
-
     ***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
     """
 
@@ -63,15 +65,20 @@ def symptom_checker_chat(user_input, history, image_input):
     conversation = [{"role": "system", "content": system_prompt}]
     for user, assistant in history:
         conversation.append({"role": "user", "content": user})
-        conversation.append({"role": "assistant", "content": assistant})
-
-    # Add the current user input
-    conversation.append({"role": "user", "content": user_input})
-
+        if assistant:  # Ensure assistant message is not None
+            conversation.append({"role": "assistant", "content": assistant})
+
+    # Add the current user input with a special image token if an image is present
+    if image_input:
+        # MedGemma expects the text to start with <image> token if an image is provided
+        conversation.append({"role": "user", "content": f"<image>\n{user_input}"})
+    else:
+        conversation.append({"role": "user", "content": user_input})
+
     # Apply the chat template
     prompt = processor.tokenizer.apply_chat_template(
-        conversation,
-        tokenize=False,
+        conversation,
+        tokenize=False,
         add_generation_prompt=True
     )
 
@@ -85,16 +92,12 @@ def symptom_checker_chat(user_input, history, image_input):
     # Generate the output from the model
     try:
         outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
-        response_text = processor.decode(outputs[0], skip_special_tokens=True)
 
-        # Find the last assistant marker in the decoded text
-        # and keep everything that follows it
-        last_assistant_marker = "assistant"
-        last_occurrence = response_text.rfind(last_assistant_marker)
-        if last_occurrence != -1:
-            clean_response = response_text[last_occurrence + len(last_assistant_marker):].strip()
-        else:
-            clean_response = "I'm sorry, I encountered an issue processing your request. Please try again."
+        # <-- FIX 2: ROBUST RESPONSE PARSING
+        # Decode only the newly generated tokens, not the whole conversation
+        input_token_len = inputs["input_ids"].shape[1]
+        generated_tokens = outputs[:, input_token_len:]
+        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
 
     except Exception as e:
         print(f"Error during model generation: {e}")
@@ -103,8 +106,8 @@ def symptom_checker_chat(user_input, history, image_input):
     # Update the history
     history.append((user_input, clean_response))
 
-    # Return the updated history and clear the image box
-    return history, None
+    # <-- FIX 3 & 4: Return values to update state, clear image box, and clear text box
+    return history, history, None, ""
 
 # Create the Gradio Interface using Blocks for more control
 with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
@@ -116,10 +119,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
     )
 
     # Chatbot component to display the conversation
-    chatbot = gr.Chatbot(label="Conversation", height=500)
+    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))  # Added avatars for fun
 
     # State to store the conversation history
-    chat_history = gr.State([])
+    chat_history = gr.State([])  # <-- FIX 3: This state will now be used correctly
 
     with gr.Row():
         # Image input
@@ -137,27 +140,30 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
 
     # Function to clear all inputs
     def clear_all():
-        return [], None, ""
+        return [], [], None, ""  # <-- FIX 3: Correctly clear the state and chatbot
 
     # Clear button
     clear_btn = gr.Button("Start New Conversation")
-    clear_btn.click(clear_all, outputs=[chatbot, image_box, text_box], queue=False)
-
+    # <-- FIX 3: The outputs list now correctly targets the state
+    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)
+
     # Define what happens when the user submits
     submit_btn.click(
        fn=symptom_checker_chat,
-        inputs=[text_box, chatbot, image_box],
-        outputs=[chatbot, image_box]
+        # <-- FIX 3 & 4: Corrected inputs and outputs
+        inputs=[text_box, chat_history, image_box],
+        outputs=[chatbot, chat_history, image_box, text_box]
     )
-
+
     # Define what happens when the user just presses Enter in the textbox
     text_box.submit(
        fn=symptom_checker_chat,
-        inputs=[text_box, chatbot, image_box],
-        outputs=[chatbot, image_box]
+        # <-- FIX 3 & 4: Corrected inputs and outputs
+        inputs=[text_box, chat_history, image_box],
+        outputs=[chatbot, chat_history, image_box, text_box]
     )
 
-
 # Launch the Gradio app
 if __name__ == "__main__":
     demo.launch(debug=True)  # Debug mode for more detailed logs
+
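As a footnote on the state-management fix: a gr.State value arrives in the handler as an ordinary argument, and whatever the handler returns to that output becomes the new stored value. A stripped-down sketch of the same wiring, with a hypothetical echo reply standing in for the model call:

```python
import gradio as gr

def respond(user_input, history):
    # Placeholder reply; the real app calls the model here.
    history.append((user_input, f"Echo: {user_input}"))
    # chatbot display, chat_history state, and "" to clear the textbox
    return history, history, ""

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Conversation")
    chat_history = gr.State([])  # persists across submissions
    text_box = gr.Textbox(label="Message")
    text_box.submit(respond, inputs=[text_box, chat_history],
                    outputs=[chatbot, chat_history, text_box])

if __name__ == "__main__":
    demo.launch()
```

Returning the history to both chatbot and chat_history keeps the display and the stored state in sync, and returning "" as the final value is what clears the textbox.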