hponepyae committed (verified)
Commit a91bbfc · 1 Parent(s): 1fa102b

Update app.py

Files changed (1)
  1. app.py +58 -44
app.py CHANGED
@@ -5,11 +5,12 @@ import torch
 import os
 import spaces
 
-# --- Initialize the Model Pipeline (No changes here) ---
+# --- Initialize the Model Pipeline (No changes) ---
 print("Loading MedGemma model...")
 try:
+    # Using "image-to-text" is more robust for modern multimodal chat models.
     pipe = pipeline(
-        "image-text-to-text",
+        "image-to-text",
         model="google/medgemma-4b-it",
         torch_dtype=torch.bfloat16,
         device_map="auto",
@@ -21,11 +22,12 @@ except Exception as e:
     model_loaded = False
     print(f"Error loading model: {e}")
 
-# --- Core Analysis Function (Corrected) ---
+# --- Core Analysis Function (Final Corrected Version) ---
 @spaces.GPU()
-def analyze_symptoms(symptom_image, symptoms_text):
+def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
     """
-    Analyzes user's symptoms using the correct prompt format and keyword arguments for MedGemma.
+    Analyzes user's symptoms using the officially recommended chat format
+    for the MedGemma multimodal model.
     """
     if not model_loaded:
         return "Error: The AI model could not be loaded. Please check the Space logs."
@@ -35,51 +37,62 @@ def analyze_symptoms(symptom_image, symptoms_text):
         return "Please describe your symptoms or upload an image for analysis."
 
     try:
-        # --- PROMPT LOGIC (Unchanged) ---
-        instruction = (
+        # --- DEFINITIVE CHAT-BASED PROMPT LOGIC ---
+
+        # 1. System Prompt: This sets the AI's persona and overall goal.
+        system_instruction = (
             "You are an expert, empathetic AI medical assistant. "
-            "Analyze the potential medical condition based on the following information. "
+            "Analyze the potential medical condition based on the user's input. "
             "Provide a list of possible conditions, your reasoning, and a clear, "
-            "actionable next-steps plan. Start your analysis by describing the user-provided "
-            "information (text and/or image)."
+            "actionable next-steps plan. Begin your analysis by describing the information "
+            "the user provided."
         )
-        prompt_parts = ["<start_of_turn>user"]
-        if symptoms_text:
-            prompt_parts.append(symptoms_text)
+
+        # 2. User Content: This must be a list of dictionaries for multimodal input.
+        user_content = []
+
+        # The model requires some form of text. If the user provides none,
+        # we add a generic prompt to accompany the image.
+        text_to_send = symptoms_text if symptoms_text else "Please analyze this medical image."
+        user_content.append({"type": "text", "text": text_to_send})
+
+        # Add the image part if it exists.
         if symptom_image:
-            prompt_parts.append("<image>")
-        prompt_parts.append(instruction)
-        prompt_parts.append("<start_of_turn>model")
-        prompt = "\n".join(prompt_parts)
-
-        print("Generating pipeline output...")
-
-        # --- CORRECTED & ROBUST PIPELINE CALL ---
-        # We build a dictionary of all arguments to pass to the pipeline.
-        # This avoids the TypeError by ensuring all arguments are passed explicitly by keyword.
+            user_content.append({"type": "image", "image": symptom_image})
+
+        # 3. Construct the full message list for the pipeline
+        messages = [
+            {"role": "system", "content": system_instruction},
+            {"role": "user", "content": user_content},
+        ]
+
+        print("Generating pipeline output with chat format...")
 
-        pipeline_args = {
-            "prompt": prompt,
-            "max_new_tokens": 512,
-            "do_sample": True,
-            "temperature": 0.7
-        }
-
-        # The `images` argument should be a list of PIL Images.
-        # We only add it to our arguments dictionary if an image is provided.
-        if symptom_image:
-            pipeline_args["images"] = [symptom_image]
-
-        # We use the ** syntax to unpack the dictionary into keyword arguments.
-        # This results in a call like: pipe(prompt=..., images=..., max_new_tokens=...)
-        output = pipe(**pipeline_args)
-
+        # --- CORRECTED PIPELINE CALL ---
+        # Pass the `messages` list directly. The pipeline's processor, which knows
+        # the model's chat template, will format it correctly.
+        output = pipe(
+            messages,
+            max_new_tokens=512,
+            do_sample=True,
+            temperature=0.7
+        )
+
         print("Pipeline Output:", output)
 
-        # --- SIMPLIFIED OUTPUT PROCESSING (Unchanged) ---
-        if output and isinstance(output, list) and 'generated_text' in output[0]:
-            full_text = output[0]['generated_text']
-            result = full_text.split("<start_of_turn>model\n")[-1]
+        # --- ROBUST OUTPUT PROCESSING ---
+        # The output from a chat-templated pipeline call is a list containing the full
+        # conversation history, including the newly generated assistant message.
+        if output and isinstance(output, list) and output[0].get('generated_text'):
+            # The generated_text contains the full conversation history
+            full_conversation = output[0]['generated_text']
+            # The last message in the list is the AI's response.
+            assistant_message = full_conversation[-1]
+            if assistant_message['role'] == 'assistant':
+                result = assistant_message['content']
+            else:
+                # Fallback in case the last message isn't from the assistant
+                result = str(assistant_message)
         else:
             result = "The model did not return a valid response. Please try again."
 
@@ -91,7 +104,8 @@ def analyze_symptoms(symptom_image, symptoms_text):
         print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
         return f"An error occurred during analysis. Please check the logs for details: {str(e)}"
 
-# --- Gradio Interface (No changes needed) ---
+
+# --- Create the Gradio Interface (No changes needed) ---
 with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
     gr.HTML("""
         <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
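The core of this commit is the switch from a hand-built <start_of_turn> prompt string to a chat-message payload and conversation-style output parsing. The minimal, standalone sketch below mirrors that logic without loading any model: build_messages, extract_reply, and mock_output are hypothetical names introduced here for illustration, and the mocked output only imitates the shape the new code assumes the pipeline returns; the real transformers pipeline output may differ.

# Standalone sketch of the chat payload and output parsing used in this commit.
# No model is loaded; mock_output only mimics the structure the new code expects.

def build_messages(symptoms_text, symptom_image=None):
    """Assemble the system/user chat messages that the pipeline call receives."""
    user_content = [{"type": "text",
                     "text": symptoms_text or "Please analyze this medical image."}]
    if symptom_image is not None:
        user_content.append({"type": "image", "image": symptom_image})
    return [
        {"role": "system", "content": "You are an expert, empathetic AI medical assistant."},
        {"role": "user", "content": user_content},
    ]

def extract_reply(output):
    """Pull the assistant's text out of a conversation-style generated_text list."""
    if output and isinstance(output, list) and output[0].get("generated_text"):
        last = output[0]["generated_text"][-1]
        if last.get("role") == "assistant":
            return last.get("content")
        return str(last)
    return "The model did not return a valid response. Please try again."

if __name__ == "__main__":
    messages = build_messages("Persistent cough and mild fever for three days.")
    # Mocked pipeline result: the prior conversation plus one assistant message.
    mock_output = [{"generated_text": messages + [
        {"role": "assistant", "content": "Possible conditions include ..."}]}]
    print(extract_reply(mock_output))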