hponepyae committed on
Commit
c5882f3
·
verified ·
1 Parent(s): 4334aa5

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +39 -60
app.py CHANGED
@@ -1,37 +1,32 @@
1
  import gradio as gr
2
- from transformers import AutoProcessor, AutoModelForCausalLM
3
  from PIL import Image
4
  import torch
5
  import os
6
  import spaces
7
 
8
- # --- Initialize the Model and Processor Directly ---
9
- print("Loading MedGemma model and processor...")
10
- model_id = "google/medgemma-4b-it"
11
- model = None
12
- processor = None
13
- model_loaded = False
14
-
15
  try:
16
- model = AutoModelForCausalLM.from_pretrained(
17
- model_id,
 
18
  torch_dtype=torch.bfloat16,
19
  device_map="auto",
20
  token=os.environ.get("HF_TOKEN")
21
  )
22
- processor = AutoProcessor.from_pretrained(model_id, token=os.environ.get("HF_TOKEN"))
23
  model_loaded = True
24
- print("Model and processor loaded successfully!")
25
  except Exception as e:
26
  model_loaded = False
27
  print(f"Error loading model: {e}")
28
 
29
- # --- Core Analysis Function (Final Corrected Version) ---
30
  @spaces.GPU()
31
  def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
32
  """
33
- Analyzes symptoms using the definitive two-step templating and processing method
34
- required by modern multimodal chat models.
35
  """
36
  if not model_loaded:
37
  return "Error: The AI model could not be loaded. Please check the Space logs."
@@ -41,59 +36,40 @@ def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
41
  return "Please describe your symptoms or upload an image for analysis."
42
 
43
  try:
44
- # --- STEP 1: Build the structured messages list ---
45
- system_instruction = (
46
- "You are an expert, empathetic AI medical assistant. "
47
- "Analyze the potential medical condition based on the following information. "
48
- "Provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan."
 
49
  )
50
 
51
- # The 'content' for a user's turn is a LIST of dictionaries.
52
- user_content_list = []
53
- if symptom_image:
54
- # Add a placeholder dictionary for the image.
55
- user_content_list.append({"type": "image"})
56
 
57
- # Add the dictionary for the text.
58
- text_content = f"{symptoms_text}\n\n{system_instruction}"
59
- user_content_list.append({"type": "text", "text": text_content})
60
-
61
  messages = [
62
- {"role": "user", "content": user_content_list}
 
63
  ]
64
-
65
- # --- STEP 2: Generate the prompt string using the official template ---
66
- # This will correctly create a string with all special tokens, including <image>.
67
- prompt = processor.tokenizer.apply_chat_template(
68
- messages,
69
- tokenize=False,
70
- add_generation_prompt=True
71
- )
72
-
73
- # --- STEP 3: Process the prompt string and image together ---
74
- # This is where the prompt's <image> token is linked to the actual image data.
75
- inputs = processor(
76
- text=prompt,
77
- images=symptom_image, # This can be None for text-only cases
78
- return_tensors="pt"
79
- ).to(model.device)
80
-
81
- # Generation parameters
82
- generate_kwargs = {
83
  "max_new_tokens": 512,
84
  "do_sample": True,
85
  "temperature": 0.7,
86
  }
87
-
88
- print("Generating model output with the definitive two-step process...")
89
 
90
- # Generate the response
91
- generate_ids = model.generate(**inputs, **generate_kwargs)
92
-
93
- # Decode only the newly generated tokens
94
- input_token_len = inputs["input_ids"].shape[-1]
95
- result = processor.batch_decode(generate_ids[:, input_token_len:], skip_special_tokens=True)[0]
96
 
 
 
 
97
  disclaimer = "\n\n***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***"
98
 
99
  return result.strip() + disclaimer
@@ -102,7 +78,7 @@ def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
102
  print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
103
  return f"An error occurred during analysis. Please check the logs for details: {str(e)}"
104
 
105
- # --- Gradio Interface (No changes needed) ---
106
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
107
  gr.HTML("""
108
  <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
@@ -134,10 +110,13 @@ with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
134
  label="AI Analysis", lines=25, show_copy_button=True, placeholder="Analysis results will appear here...")
135
 
136
  def clear_all():
137
- return None, "", ""
 
138
 
 
139
  analyze_btn.click(fn=analyze_symptoms, inputs=[image_input, symptoms_input], outputs=output_text)
140
- clear_btn.click(fn=clear_all, outputs=[image_input, symptoms_input, output_text])
 
141
 
142
  if __name__ == "__main__":
143
  print("Starting Gradio interface...")
 
1
  import gradio as gr
2
+ from transformers import pipeline
3
  from PIL import Image
4
  import torch
5
  import os
6
  import spaces
7
 
8
+ # --- Initialize the Model Pipeline (As per your working example) ---
9
+ print("Loading MedGemma model...")
 
 
 
 
 
10
  try:
11
+ pipe = pipeline(
12
+ "image-text-to-text",
13
+ model="google/medgemma-4b-it",
14
  torch_dtype=torch.bfloat16,
15
  device_map="auto",
16
  token=os.environ.get("HF_TOKEN")
17
  )
 
18
  model_loaded = True
19
+ print("Model loaded successfully!")
20
  except Exception as e:
21
  model_loaded = False
22
  print(f"Error loading model: {e}")
23
 
24
+ # --- Core Analysis Function (Using the logic from your working example) ---
25
  @spaces.GPU()
26
  def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
27
  """
28
+ Analyzes user's symptoms using the definitive calling convention demonstrated
29
+ in the working X-ray analyzer example.
30
  """
31
  if not model_loaded:
32
  return "Error: The AI model could not be loaded. Please check the Space logs."
 
36
  return "Please describe your symptoms or upload an image for analysis."
37
 
38
  try:
39
+ # --- DEFINITIVE MESSAGE CONSTRUCTION (from your example) ---
40
+ system_prompt = (
41
+ "You are an expert, empathetic AI medical assistant. Analyze the potential "
42
+ "medical condition based on the following information. Provide a list of "
43
+ "possible conditions, your reasoning, and a clear, actionable next-steps plan. "
44
+ "Start your analysis by describing the user-provided information."
45
  )
46
 
47
+ user_content = []
48
+ # The user's prompt text is always present.
49
+ user_content.append({"type": "text", "text": symptoms_text})
 
 
50
 
51
+ # The actual PIL image object is added to the content list if it exists.
52
+ if symptom_image:
53
+ user_content.append({"type": "image", "image": symptom_image})
54
+
55
  messages = [
56
+ {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
57
+ {"role": "user", "content": user_content}
58
  ]
59
+
60
+ generation_args = {
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
61
  "max_new_tokens": 512,
62
  "do_sample": True,
63
  "temperature": 0.7,
64
  }
 
 
65
 
66
+ # --- DEFINITIVE PIPELINE CALL (from your example) ---
67
+ # The entire messages structure is passed to the `text` argument.
68
+ output = pipe(text=messages, **generation_args)
 
 
 
69
 
70
+ # The result is the 'content' of the last generated message.
71
+ result = output[0]["generated_text"][-1]["content"]
72
+
73
  disclaimer = "\n\n***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***"
74
 
75
  return result.strip() + disclaimer
 
78
  print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
79
  return f"An error occurred during analysis. Please check the logs for details: {str(e)}"
80
 
81
+ # --- Gradio Interface (Your original, no changes needed) ---
82
  with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
83
  gr.HTML("""
84
  <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
 
110
  label="AI Analysis", lines=25, show_copy_button=True, placeholder="Analysis results will appear here...")
111
 
112
  def clear_all():
113
+ # This function should return values for all outputs cleared by the button
114
+ return None, ""
115
 
116
+ # The clear button now correctly clears the image and text input.
117
  analyze_btn.click(fn=analyze_symptoms, inputs=[image_input, symptoms_input], outputs=output_text)
118
+ clear_btn.click(fn=lambda: (None, "", ""), outputs=[image_input, symptoms_input, output_text])
119
+
120
 
121
  if __name__ == "__main__":
122
  print("Starting Gradio interface...")