hponepyae committed
Commit bd084e6 · verified · 1 parent: b67fca4

Update app.py


GPU Request: added "import spaces" and the @spaces.GPU decorator to symptom_checker_chat so the function is granted GPU time on the Space.
State Management: the click and submit events now use chat_history as both an input and an output; symptom_checker_chat accepts the history from that state and returns the updated list to both the chatbot and chat_history.
Robust Parsing: replaced the fragile rfind() logic with a more reliable approach that decodes only the newly generated tokens.
UI Cleanup: added text_box to the outputs of the event handlers; the function now returns "" as its last value to clear the textbox after each submission.

Minimal standalone sketches of these patterns follow.
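First, the ZeroGPU pattern in isolation. This is a minimal sketch, not the app's actual function: run_inference and its body are stand-ins, and it assumes the code runs on a Hugging Face Space where the spaces package is available.

```python
import spaces  # Hugging Face Spaces helper; provides the ZeroGPU decorator

@spaces.GPU  # a GPU is attached only while the decorated function executes
def run_inference(prompt: str) -> str:
    # Stand-in body; in app.py this is where tokenization and
    # model.generate() run with CUDA available.
    return prompt
```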
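Next, the gr.State wiring and the textbox-clearing return value, reduced to a self-contained echo bot. Component names mirror app.py, but respond is a toy stand-in for symptom_checker_chat; the tuple-style history matches what the app uses.

```python
import gradio as gr

def respond(user_input, history):
    # Stand-in for symptom_checker_chat: append an echo reply.
    history = history + [(user_input, f"You said: {user_input}")]
    # Return the updated list to BOTH the visible chatbot and the state,
    # and "" as the last value so the textbox clears after submission.
    return history, history, ""

with gr.Blocks() as demo:
    chatbot = gr.Chatbot(label="Conversation")
    chat_history = gr.State([])  # server-side copy of the conversation
    text_box = gr.Textbox(placeholder="Type a message and press Enter")
    text_box.submit(
        respond,
        inputs=[text_box, chat_history],
        outputs=[chatbot, chat_history, text_box],
    )

if __name__ == "__main__":
    demo.launch()
```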
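Finally, the robust-parsing idea as a helper: model.generate() returns the prompt tokens followed by the newly generated ones, so slicing at the prompt length isolates the reply without any string searching. A sketch only; decode_new_tokens is a hypothetical name, and processor, inputs, and outputs are the transformers objects already present in app.py.

```python
def decode_new_tokens(processor, inputs, outputs):
    """Decode only the tokens generated after the prompt.

    inputs:  the tokenized prompt, a dict containing "input_ids"
    outputs: the tensor returned by model.generate(), which echoes
             the prompt tokens before the newly generated ones
    """
    input_token_len = inputs["input_ids"].shape[1]  # prompt length
    new_tokens = outputs[:, input_token_len:]       # reply tokens only
    return processor.decode(new_tokens[0], skip_special_tokens=True).strip()
```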

Files changed (1)
  1. app.py +40 -34
app.py CHANGED

@@ -3,8 +3,10 @@ import torch
 from transformers import AutoProcessor, AutoModelForCausalLM
 from PIL import Image
 import os
+import spaces  # <-- FIX 1: IMPORT SPACES
 
 # Get the Hugging Face token from the environment variables
+# Make sure to set this as a "Secret" in your Hugging Face Space settings
 hf_token = os.environ.get("HF_TOKEN")
 
 # Initialize the processor and model
@@ -16,7 +18,8 @@ model_id = "google/medgemma-4b-it"
 if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
     dtype = torch.bfloat16
 else:
-    dtype = torch.float32
+    # Fallback to float16 if bfloat16 is not available
+    dtype = torch.float16
 
 model_loaded = False
 # Load the processor and model from Hugging Face
@@ -34,20 +37,20 @@ except Exception as e:
     print(f"Error loading model: {e}")
     # We will display an error in the UI if the model fails to load.
 
-
 # This is the core function for the chatbot
+@spaces.GPU  # <-- FIX 1: ADD THE GPU DECORATOR
 def symptom_checker_chat(user_input, history, image_input):
     """
     Manages the conversational flow for the symptom checker.
     """
     if not model_loaded:
         history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
-        return history, None
+        # <-- FIX 3 & 4: Return values match new outputs
+        return history, history, None, ""
 
     # System prompt to guide the model's behavior
     system_prompt = """
     You are an expert, empathetic AI medical assistant. Your role is to analyze a user's symptoms and provide a helpful, safe, and informative response.
-
     Here is your workflow:
     1. Analyze the user's initial input, which may include text and an image.
     2. If the information is insufficient, ask specific, relevant clarifying questions to better understand the symptoms (e.g., "How long have you had this symptom?", "Can you describe the pain? Is it sharp or dull?").
@@ -55,7 +58,6 @@ def symptom_checker_chat(user_input, history, image_input):
     4. For each possible condition, briefly explain why it might be relevant.
     5. Provide a clear, actionable plan, such as "It would be best to monitor your symptoms," or "You should consider consulting a healthcare professional."
     6. **Crucially, you must ALWAYS end every single response with the following disclaimer, formatted exactly like this, on a new line:**
-
     ***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
     """
 
@@ -63,15 +65,20 @@ def symptom_checker_chat(user_input, history, image_input):
     conversation = [{"role": "system", "content": system_prompt}]
     for user, assistant in history:
         conversation.append({"role": "user", "content": user})
-        conversation.append({"role": "assistant", "content": assistant})
-
-    # Add the current user input
-    conversation.append({"role": "user", "content": user_input})
-
+        if assistant:  # Ensure assistant message is not None
+            conversation.append({"role": "assistant", "content": assistant})
+
+    # Add the current user input with a special image token if an image is present
+    if image_input:
+        # MedGemma expects the text to start with <image> token if an image is provided
+        conversation.append({"role": "user", "content": f"<image>\n{user_input}"})
+    else:
+        conversation.append({"role": "user", "content": user_input})
+
     # Apply the chat template
     prompt = processor.tokenizer.apply_chat_template(
-        conversation,
-        tokenize=False,
+        conversation,
+        tokenize=False,
         add_generation_prompt=True
     )
 
@@ -85,16 +92,12 @@ def symptom_checker_chat(user_input, history, image_input):
     # Generate the output from the model
    try:
         outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
-        response_text = processor.decode(outputs[0], skip_special_tokens=True)
 
-        # Clean the response to only get the assistant's part
-        # This logic finds the last assistant message in the generated text
-        last_assistant_marker = "assistant\n"
-        last_occurrence = response_text.rfind(last_assistant_marker)
-        if last_occurrence != -1:
-            clean_response = response_text[last_occurrence + len(last_assistant_marker):].strip()
-        else:
-            clean_response = "I'm sorry, I encountered an issue processing your request. Please try again."
+        # <-- FIX 2: ROBUST RESPONSE PARSING
+        # Decode only the newly generated tokens, not the whole conversation
+        input_token_len = inputs["input_ids"].shape[1]
+        generated_tokens = outputs[:, input_token_len:]
+        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
 
     except Exception as e:
         print(f"Error during model generation: {e}")
@@ -103,8 +106,8 @@ def symptom_checker_chat(user_input, history, image_input):
     # Update the history
     history.append((user_input, clean_response))
 
-    return history, None # Return updated history and clear the image input
-
+    # <-- FIX 3 & 4: Return values to update state, clear image box, and clear text box
+    return history, history, None, ""
 
 # Create the Gradio Interface using Blocks for more control
 with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
@@ -116,10 +119,10 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
     )
 
     # Chatbot component to display the conversation
-    chatbot = gr.Chatbot(label="Conversation", height=500)
+    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))  # Added avatars for fun
 
     # State to store the conversation history
-    chat_history = gr.State([])
+    chat_history = gr.State([])  # <-- FIX 3: This state will now be used correctly
 
     with gr.Row():
         # Image input
@@ -137,27 +140,30 @@ with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}")
 
     # Function to clear all inputs
     def clear_all():
-        return [], None, "", None
+        return [], [], None, ""  # <-- FIX 3: Correctly clear the state and chatbot
 
     # Clear button
     clear_btn = gr.Button("Start New Conversation")
-    clear_btn.click(clear_all, outputs=[chatbot, image_box, text_box, chat_history], queue=False)
-
+    # <-- FIX 3: The outputs list now correctly targets the state
+    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)
+
     # Define what happens when the user submits
     submit_btn.click(
         fn=symptom_checker_chat,
-        inputs=[text_box, chatbot, image_box],
-        outputs=[chatbot, image_box] # Also clear the image box after submission
+        # <-- FIX 3 & 4: Corrected inputs and outputs
+        inputs=[text_box, chat_history, image_box],
+        outputs=[chatbot, chat_history, image_box, text_box]
    )
 
-
+
     # Define what happens when the user just presses Enter in the textbox
     text_box.submit(
         fn=symptom_checker_chat,
-        inputs=[text_box, chatbot, image_box],
-        outputs=[chatbot, image_box]
+        # <-- FIX 3 & 4: Corrected inputs and outputs
+        inputs=[text_box, chat_history, image_box],
+        outputs=[chatbot, chat_history, image_box, text_box]
    )
 
-
 # Launch the Gradio app
 if __name__ == "__main__":
     demo.launch(debug=True)  # Debug mode for more detailed logs
+