hponepyae committed on
Commit
909352f
·
verified ·
1 Parent(s): 8a813aa

Changed to conversation style

Browse files
Files changed (1) hide show
  1. app.py +90 -86
app.py CHANGED
@@ -8,6 +8,7 @@ import spaces
8
  # --- Initialize the Model Pipeline ---
9
  print("Loading MedGemma model...")
10
  try:
 
11
  pipe = pipeline(
12
  "image-text-to-text",
13
  model="google/medgemma-4b-it",
@@ -21,115 +22,118 @@ except Exception as e:
21
  model_loaded = False
22
  print(f"Error loading model: {e}")
23
 
24
# --- Core Analysis Function (Final Robust Version) ---
@spaces.GPU()
def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
    """
    Analyzes user's symptoms with separate, robust logic for image and text-only inputs.

    Args:
        symptom_image: Optional PIL image of the symptom; may be None.
        symptoms_text: Free-text symptom description; may be empty or None.

    Returns:
        The model's analysis suffixed with a medical disclaimer, or a
        human-readable error message when the model is unavailable or fails.
    """
    if not model_loaded:
        return "Error: The AI model could not be loaded. Please check the Space logs."

    symptoms_text = symptoms_text.strip() if symptoms_text else ""
    if symptom_image is None and not symptoms_text:
        return "Please describe your symptoms or upload an image for analysis."

    try:
        system_prompt = (
            "You are an expert, empathetic AI medical assistant. Analyze the potential "
            "medical condition based on the following information. Provide a list of "
            "possible conditions, your reasoning, and a clear, actionable next-steps plan. "
            "Start your analysis by describing the user-provided information."
        )

        generation_args = {
            "max_new_tokens": 1024,
            "do_sample": True,
            "temperature": 0.7,
        }

        if symptom_image:
            # --- PATH 1: Image is present. Use the proven 'messages' format. ---
            print("Image detected. Using multimodal 'messages' format...")
            user_content = []

            # Only add text content if it actually exists.
            if symptoms_text:
                user_content.append({"type": "text", "text": symptoms_text})

            user_content.append({"type": "image", "image": symptom_image})

            messages = [
                {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
                {"role": "user", "content": user_content},
            ]

            output = pipe(text=messages, **generation_args)
            result = output[0]["generated_text"][-1]["content"]
        else:
            # --- PATH 2: No image. Use a raw Gemma-format prompt string. ---
            print("No image detected. Using robust 'text-only' format...")

            # FIX: the previous template never closed turns with <end_of_turn>
            # and used a "system" turn, which Gemma's chat template does not
            # define. Fold the instructions into the user turn and terminate
            # it properly so the model sees a well-formed prompt.
            prompt = (
                f"<start_of_turn>user\n{system_prompt}\n\n"
                f"{symptoms_text}<end_of_turn>\n<start_of_turn>model\n"
            )

            # A simple string call returns the full text including the prompt,
            # so we must split the generated answer off.
            output = pipe(prompt, **generation_args)
            full_text = output[0]["generated_text"]
            result = full_text.split("<start_of_turn>model\n")[-1]

        disclaimer = (
            "\n\n***Disclaimer: I am an AI assistant and not a medical professional. "
            "This is not a diagnosis. Please consult a doctor for any health concerns.***"
        )
        return result.strip() + disclaimer

    except Exception as e:
        print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
        return f"An error occurred during analysis. Please check the logs for details: {str(e)}"
 
 
 
 
 
 
98
 
99
# --- Gradio Interface (No changes needed) ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
    # Header banner.
    gr.HTML("""
    <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
    <h1>🩺 AI Symptom Analyzer</h1>
    <p>Advanced symptom analysis powered by Google's MedGemma AI</p>
    </div>
    """)
    # Always-visible medical disclaimer.
    gr.HTML("""
    <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
    <strong>⚠️ Medical Disclaimer:</strong> This AI tool is for informational purposes only and is not a substitute for professional medical diagnosis or treatment.
    </div>
    """)

    with gr.Row(equal_height=True):
        # Left column: user inputs (text description + optional image).
        with gr.Column(scale=1):
            gr.Markdown("### 1. Describe Your Symptoms")
            symptoms_input = gr.Textbox(
                label="Symptoms",
                placeholder="e.g., 'I have a rash on my arm that is red and itchy...'", lines=5)
            gr.Markdown("### 2. Upload an Image (Optional)")
            image_input = gr.Image(label="Symptom Image", type="pil", height=300)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
                analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")

        # Right column: the model's analysis report.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Analysis Report")
            output_text = gr.Textbox(
                label="AI Analysis", lines=25, show_copy_button=True, placeholder="Analysis results will appear here...")

    # Event handlers
    analyze_btn.click(fn=analyze_symptoms, inputs=[image_input, symptoms_input], outputs=output_text)
    # Clearing resets the image, the symptom text, and the report in one go.
    clear_btn.click(fn=lambda: (None, "", ""), outputs=[image_input, symptoms_input, output_text])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
  if __name__ == "__main__":
135
  print("Starting Gradio interface...")
 
8
  # --- Initialize the Model Pipeline ---
9
  print("Loading MedGemma model...")
10
  try:
11
+ # We use the same pipeline, but our interaction with it will be different.
12
  pipe = pipeline(
13
  "image-text-to-text",
14
  model="google/medgemma-4b-it",
 
22
  model_loaded = False
23
  print(f"Error loading model: {e}")
24
 
25
# --- Core CONVERSATIONAL Logic ---
@spaces.GPU()
def handle_conversation_turn(user_input: str, user_image: Image.Image, history: list):
    """
    Manages a single turn of the conversation, maintaining history.

    Args:
        user_input: The user's latest text message (may be empty).
        user_image: Optional PIL image uploaded this turn; it is "consumed"
            by the turn and not stored in the history.
        history: Conversation so far, as a list of (user, assistant) tuples.

    Returns:
        A (history, None) pair: the updated history for the chatbot display,
        and None so the caller clears the image box.
    """
    if not model_loaded:
        # Append an error message to the chatbot history
        history.append((user_input, "Error: The AI model is not loaded. Please contact the administrator."))
        return history, None

    # FIX: guard against a completely empty turn. Without this, an empty
    # user-content list would be sent to the pipeline, which fails.
    if not user_input and user_image is None:
        history.append((user_input, "Please type a message or upload an image to continue."))
        return history, None

    try:
        # --- 1. Define the AI's persona and instructions ---
        # This is the most critical part for controlling the conversational flow.
        system_prompt = (
            "You are an expert, empathetic AI medical assistant conducting a virtual consultation. "
            "Your primary goal is to ask clarifying questions to understand the user's symptoms thoroughly. "
            "Do NOT provide a diagnosis or a list of possibilities right away. "
            "Your first step is ALWAYS to ask relevant follow-up questions. Ask only one or two focused questions per turn. "
            "If the user provides an image, acknowledge it by describing what you see in the image first, then ask your questions. "
            "After several turns of asking questions, when you feel you have gathered enough information, you must FIRST state that you are ready to provide a summary. "
            "THEN, in the SAME response, provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
        )

        # --- 2. Format the conversation for the model ---
        # The history needs to be converted into the format the model expects.
        messages = [{"role": "system", "content": [{"type": "text", "text": system_prompt}]}]

        # Add past interactions from the history
        for user_msg, assistant_msg in history:
            messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})

        # Add the LATEST user input, including the image if provided
        latest_user_content = []
        if user_input:
            latest_user_content.append({"type": "text", "text": user_input})
        if user_image:
            latest_user_content.append({"type": "image", "image": user_image})

        messages.append({"role": "user", "content": latest_user_content})

        # --- 3. Call the pipeline ---
        generation_args = {"max_new_tokens": 1024, "do_sample": True, "temperature": 0.7}

        output = pipe(text=messages, **generation_args)
        ai_response = output[0]["generated_text"][-1]["content"]

        # --- 4. Update the history ---
        # We store the user's text input and the AI's response. The image is "consumed" in the turn.
        history.append((user_input, ai_response))

        # We return the updated history for the chatbot display and None to clear the image box.
        return history, None

    except Exception as e:
        history.append((user_input, f"An error occurred: {str(e)}"))
        print(f"An exception occurred during conversation turn: {type(e).__name__}: {e}")
        return history, None
84
+
85
# --- Gradio Interface for Conversational Flow ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Doctor Consultation") as demo:
    # gr.State holds the conversation history as a list of (user, ai) tuples.
    conversation_history = gr.State([])

    # Header banner.
    gr.HTML("""
    <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
    <h1>🩺 AI Symptom Consultation</h1>
    <p>A conversational AI to help you understand your symptoms, powered by Google's MedGemma</p>
    </div>
    """)
    # Always-visible medical disclaimer.
    gr.HTML("""
    <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
    <strong>⚠️ Medical Disclaimer:</strong> This is not a diagnosis. This AI is for informational purposes and is not a substitute for professional medical advice.
    </div>
    """)

    # The chatbot component will display the conversation history.
    chatbot_display = gr.Chatbot(height=500, label="Consultation")

    with gr.Row():
        # Image box that is cleared after each turn (the handler returns None for it).
        image_input = gr.Image(label="Upload Symptom Image (Optional)", type="pil", height=150)

        with gr.Column(scale=4):
            # The textbox for the user to type their message.
            user_textbox = gr.Textbox(
                label="Your Message",
                placeholder="Describe your primary symptom to begin...",
                lines=4,
            )
            send_button = gr.Button("Send Message", variant="primary")

    # FIX: removed the unused `submit_message` wrapper. It was never wired to
    # any event, and it returned the gr.State *component* object (rather than
    # its value) in its output tuple — dead code carrying a latent bug.
    # The history list is mutated in place by the handler, so the State stays
    # in sync without being listed as an output.
    send_button.click(
        fn=handle_conversation_turn,
        inputs=[user_textbox, image_input, conversation_history],
        outputs=[chatbot_display, image_input]  # Update the chatbot and clear the image
    ).then(
        # Clear the user's text box after the message is sent.
        lambda: "",
        outputs=user_textbox
    )

    # Add a clear button for convenience: resets chat, state, image, and text.
    clear_button = gr.Button("🗑️ Start New Consultation")
    clear_button.click(lambda: ([], [], None, ""), outputs=[chatbot_display, conversation_history, image_input, user_textbox])
137
 
138
  if __name__ == "__main__":
139
  print("Starting Gradio interface...")