hponepyae commited on
Commit
1fa102b
·
verified ·
1 Parent(s): 60665db

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +17 -29
app.py CHANGED
@@ -5,7 +5,7 @@ import torch
5
  import os
6
  import spaces
7
 
8
- # --- Initialize the Model Pipeline ---
9
  print("Loading MedGemma model...")
10
  try:
11
  pipe = pipeline(
@@ -25,23 +25,17 @@ except Exception as e:
25
  @spaces.GPU()
26
  def analyze_symptoms(symptom_image, symptoms_text):
27
  """
28
- Analyzes user's symptoms using the correct prompt format for MedGemma.
29
  """
30
  if not model_loaded:
31
  return "Error: The AI model could not be loaded. Please check the Space logs."
32
 
33
- # Standardize input to avoid issues with None or whitespace
34
  symptoms_text = symptoms_text.strip() if symptoms_text else ""
35
-
36
  if symptom_image is None and not symptoms_text:
37
  return "Please describe your symptoms or upload an image for analysis."
38
 
39
  try:
40
- # --- CORRECTED PROMPT LOGIC ---
41
- # MedGemma expects a specific prompt format with special tokens.
42
- # We build this prompt string dynamically.
43
-
44
- # This is the instruction part of the prompt
45
  instruction = (
46
  "You are an expert, empathetic AI medical assistant. "
47
  "Analyze the potential medical condition based on the following information. "
@@ -49,47 +43,42 @@ def analyze_symptoms(symptom_image, symptoms_text):
49
  "actionable next-steps plan. Start your analysis by describing the user-provided "
50
  "information (text and/or image)."
51
  )
52
-
53
- # Build the final prompt based on user inputs
54
  prompt_parts = ["<start_of_turn>user"]
55
  if symptoms_text:
56
  prompt_parts.append(symptoms_text)
57
-
58
- # The <image> token is a placeholder that tells the model where to "look" at the image.
59
  if symptom_image:
60
  prompt_parts.append("<image>")
61
-
62
  prompt_parts.append(instruction)
63
  prompt_parts.append("<start_of_turn>model")
64
-
65
  prompt = "\n".join(prompt_parts)
66
 
67
  print("Generating pipeline output...")
68
 
69
- # --- CORRECTED PIPELINE CALL ---
70
- # The pipeline expects the prompt string and an 'images' argument (if an image is provided).
71
- # We create a dictionary for keyword arguments to pass to the pipeline.
72
- pipeline_kwargs = {
 
 
73
  "max_new_tokens": 512,
74
  "do_sample": True,
75
  "temperature": 0.7
76
  }
77
 
78
  # The `images` argument should be a list of PIL Images.
 
79
  if symptom_image:
80
- output = pipe(prompt, images=[symptom_image], **pipeline_kwargs)
81
- else:
82
- # If no image is provided, we do not include the `images` argument in the call.
83
- output = pipe(prompt, **pipeline_kwargs)
 
84
 
85
  print("Pipeline Output:", output)
86
 
87
- # --- SIMPLIFIED OUTPUT PROCESSING ---
88
- # The pipeline returns a list with one dictionary. The result is in the 'generated_text' key.
89
  if output and isinstance(output, list) and 'generated_text' in output[0]:
90
- # We extract just the model's response part of the generated text.
91
  full_text = output[0]['generated_text']
92
- # The model output includes the prompt, so we split it to get only the new part.
93
  result = full_text.split("<start_of_turn>model\n")[-1]
94
  else:
95
  result = "The model did not return a valid response. Please try again."
@@ -100,10 +89,9 @@ def analyze_symptoms(symptom_image, symptoms_text):
100
 
101
  except Exception as e:
102
  print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
103
- # Provide a more user-friendly error message
104
  return f"An error occurred during analysis. Please check the logs for details: {str(e)}"
105
 
106
- # --- Create the Gradio Interface (No changes needed here) ---
107
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
108
  gr.HTML("""
109
  <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
 
5
  import os
6
  import spaces
7
 
8
+ # --- Initialize the Model Pipeline (No changes here) ---
9
  print("Loading MedGemma model...")
10
  try:
11
  pipe = pipeline(
 
25
  @spaces.GPU()
26
  def analyze_symptoms(symptom_image, symptoms_text):
27
  """
28
+ Analyzes user's symptoms using the correct prompt format and keyword arguments for MedGemma.
29
  """
30
  if not model_loaded:
31
  return "Error: The AI model could not be loaded. Please check the Space logs."
32
 
 
33
  symptoms_text = symptoms_text.strip() if symptoms_text else ""
 
34
  if symptom_image is None and not symptoms_text:
35
  return "Please describe your symptoms or upload an image for analysis."
36
 
37
  try:
38
+ # --- PROMPT LOGIC (Unchanged) ---
 
 
 
 
39
  instruction = (
40
  "You are an expert, empathetic AI medical assistant. "
41
  "Analyze the potential medical condition based on the following information. "
 
43
  "actionable next-steps plan. Start your analysis by describing the user-provided "
44
  "information (text and/or image)."
45
  )
 
 
46
  prompt_parts = ["<start_of_turn>user"]
47
  if symptoms_text:
48
  prompt_parts.append(symptoms_text)
 
 
49
  if symptom_image:
50
  prompt_parts.append("<image>")
 
51
  prompt_parts.append(instruction)
52
  prompt_parts.append("<start_of_turn>model")
 
53
  prompt = "\n".join(prompt_parts)
54
 
55
  print("Generating pipeline output...")
56
 
57
+ # --- CORRECTED & ROBUST PIPELINE CALL ---
58
+ # We build a dictionary of all arguments to pass to the pipeline.
59
+ # This avoids the TypeError by ensuring all arguments are passed explicitly by keyword.
60
+
61
+ pipeline_args = {
62
+ "prompt": prompt,
63
  "max_new_tokens": 512,
64
  "do_sample": True,
65
  "temperature": 0.7
66
  }
67
 
68
  # The `images` argument should be a list of PIL Images.
69
+ # We only add it to our arguments dictionary if an image is provided.
70
  if symptom_image:
71
+ pipeline_args["images"] = [symptom_image]
72
+
73
+ # We use the ** syntax to unpack the dictionary into keyword arguments.
74
+ # This results in a call like: pipe(prompt=..., images=..., max_new_tokens=...)
75
+ output = pipe(**pipeline_args)
76
 
77
  print("Pipeline Output:", output)
78
 
79
+ # --- SIMPLIFIED OUTPUT PROCESSING (Unchanged) ---
 
80
  if output and isinstance(output, list) and 'generated_text' in output[0]:
 
81
  full_text = output[0]['generated_text']
 
82
  result = full_text.split("<start_of_turn>model\n")[-1]
83
  else:
84
  result = "The model did not return a valid response. Please try again."
 
89
 
90
  except Exception as e:
91
  print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
 
92
  return f"An error occurred during analysis. Please check the logs for details: {str(e)}"
93
 
94
+ # --- Gradio Interface (No changes needed) ---
95
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
96
  gr.HTML("""
97
  <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">