hponepyae committed on
Commit
9f24600
·
verified ·
1 Parent(s): d305e52

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +26 -29
app.py CHANGED
@@ -13,7 +13,6 @@ processor = None
13
  model_loaded = False
14
 
15
  try:
16
- # We load the model and its dedicated processor separately.
17
  model = AutoModelForCausalLM.from_pretrained(
18
  model_id,
19
  torch_dtype=torch.bfloat16,
@@ -31,8 +30,8 @@ except Exception as e:
31
  @spaces.GPU()
32
  def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
33
  """
34
- Analyzes symptoms by directly using the model and processor with the correct,
35
- two-step templating and processing logic.
36
  """
37
  if not model_loaded:
38
  return "Error: The AI model could not be loaded. Please check the Space logs."
@@ -42,53 +41,51 @@ def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
42
  return "Please describe your symptoms or upload an image for analysis."
43
 
44
  try:
45
- # --- DEFINITIVE PROMPT & INPUT PREPARATION ---
46
-
47
- # 1. Combine all text inputs into a single string for the user's turn.
48
- # Add the <image> placeholder only if an image is provided.
49
  system_instruction = (
50
  "You are an expert, empathetic AI medical assistant. "
51
  "Analyze the potential medical condition based on the following information. "
52
  "Provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
53
  )
54
-
55
- user_content = ""
56
- if symptom_image:
57
- # The model expects the <image> token to know where to place the image.
58
- user_content += "<image>\n"
59
-
60
- # Combine user text and system instructions for the user's message.
61
- user_content += f"{symptoms_text}\n\n{system_instruction}"
62
 
63
- messages = [
64
- {"role": "user", "content": user_content}
65
- ]
66
 
67
- # 2. Use the tokenizer to apply the model's specific chat template.
68
- # This correctly formats the text with all required special tokens.
69
- prompt = processor.tokenizer.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
 
 
 
 
 
 
 
 
 
70
 
71
- # 3. Use the main processor to prepare the final model inputs.
72
- # It takes the formatted text and the PIL image and creates the tensors.
73
  inputs = processor(
74
  text=prompt,
75
- images=symptom_image, # This can be None if no image is provided
76
  return_tensors="pt"
77
  ).to(model.device)
78
 
79
- # 4. Generation parameters
80
  generate_kwargs = {
81
  "max_new_tokens": 512,
82
  "do_sample": True,
83
  "temperature": 0.7,
84
  }
85
 
86
- print("Generating model output directly...")
87
 
88
- # 5. Generate the response
89
  generate_ids = model.generate(**inputs, **generate_kwargs)
90
 
91
- # 6. Decode only the newly generated tokens back into a string.
92
  input_token_len = inputs["input_ids"].shape[-1]
93
  result = processor.batch_decode(generate_ids[:, input_token_len:], skip_special_tokens=True)[0]
94
 
@@ -103,7 +100,7 @@ def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
103
  # --- Gradio Interface (No changes needed) ---
104
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
105
  gr.HTML("""
106
- <div style="text-align: center; background: linear-gradient(135deg, #66eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
107
  <h1>🩺 AI Symptom Analyzer</h1>
108
  <p>Advanced symptom analysis powered by Google's MedGemma AI</p>
109
  </div>
 
13
  model_loaded = False
14
 
15
  try:
 
16
  model = AutoModelForCausalLM.from_pretrained(
17
  model_id,
18
  torch_dtype=torch.bfloat16,
 
30
  @spaces.GPU()
31
  def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
32
  """
33
+ Analyzes symptoms by MANUALLY constructing the prompt string to ensure all special
34
+ tokens are correctly placed, bypassing the faulty chat template abstraction.
35
  """
36
  if not model_loaded:
37
  return "Error: The AI model could not be loaded. Please check the Space logs."
 
41
  return "Please describe your symptoms or upload an image for analysis."
42
 
43
  try:
44
+ # --- DEFINITIVE MANUAL PROMPT CONSTRUCTION ---
45
+
 
 
46
  system_instruction = (
47
  "You are an expert, empathetic AI medical assistant. "
48
  "Analyze the potential medical condition based on the following information. "
49
  "Provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
50
  )
 
 
 
 
 
 
 
 
51
 
52
+ # 1. Manually build the prompt string as a list of parts.
53
+ prompt_parts = ["<start_of_turn>user"]
 
54
 
55
+ # 2. CRUCIAL: Add the <image> placeholder *only* if an image exists.
56
+ if symptom_image:
57
+ prompt_parts.append("<image>")
58
+
59
+ # 3. Add all text content.
60
+ prompt_parts.append(f"{symptoms_text}\n\n{system_instruction}")
61
+
62
+ # 4. Signal the start of the model's turn.
63
+ prompt_parts.append("<start_of_turn>model")
64
+
65
+ # 5. Join all parts into a single string. This is our final prompt.
66
+ prompt = "\n".join(prompt_parts)
67
 
68
+ # 6. Use the processor with our manually built prompt. It will now find the <image>
69
+ # token and correctly process the associated image object.
70
  inputs = processor(
71
  text=prompt,
72
+ images=symptom_image, # This will be None for text-only, which is now handled correctly.
73
  return_tensors="pt"
74
  ).to(model.device)
75
 
76
+ # 7. Generation parameters
77
  generate_kwargs = {
78
  "max_new_tokens": 512,
79
  "do_sample": True,
80
  "temperature": 0.7,
81
  }
82
 
83
+ print("Generating model output with manually constructed prompt...")
84
 
85
+ # 8. Generate the response
86
  generate_ids = model.generate(**inputs, **generate_kwargs)
87
 
88
+ # 9. Decode only the newly generated tokens. This logic is correct.
89
  input_token_len = inputs["input_ids"].shape[-1]
90
  result = processor.batch_decode(generate_ids[:, input_token_len:], skip_special_tokens=True)[0]
91
 
 
100
  # --- Gradio Interface (No changes needed) ---
101
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
102
  gr.HTML("""
103
+ <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
104
  <h1>🩺 AI Symptom Analyzer</h1>
105
  <p>Advanced symptom analysis powered by Google's MedGemma AI</p>
106
  </div>