import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces
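
# Note: google/medgemma-4b-it is a gated checkpoint, so HF_TOKEN must point to
# a token whose account has accepted the model's terms on the Hugging Face Hub
# (on Spaces, set it as a secret). A minimal pre-flight check, assuming the same
# environment variable, might be:
#
#   if not os.environ.get("HF_TOKEN"):
#       print("Warning: HF_TOKEN is not set; loading the gated model will fail.")
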
# --- Initialize the Model Pipeline ---
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")
# --- Core Analysis Function ---
@spaces.GPU()  # on ZeroGPU Spaces, this requests a GPU for the duration of the call
def analyze_symptoms(symptom_image, symptoms_text):
    """
    Analyzes the user's symptoms using the Gemma chat-format prompt MedGemma expects.
    """
    if not model_loaded:
        return "Error: The AI model could not be loaded. Please check the Space logs."

    # Normalize input so None or whitespace-only text becomes an empty string
    symptoms_text = symptoms_text.strip() if symptoms_text else ""
    if symptom_image is None and not symptoms_text:
        return "Please describe your symptoms or upload an image for analysis."
    try:
        # --- PROMPT LOGIC ---
        # MedGemma (a Gemma 3 variant) expects the Gemma chat format with
        # special turn tokens, so we build the prompt string dynamically.

        # The instruction part of the prompt
        instruction = (
            "You are an expert, empathetic AI medical assistant. "
            "Analyze the potential medical condition based on the following information. "
            "Provide a list of possible conditions, your reasoning, and a clear, "
            "actionable next-steps plan. Start your analysis by describing the user-provided "
            "information (text and/or image)."
        )

        # Build the final prompt from the user inputs
        prompt_parts = ["<start_of_turn>user"]
        if symptoms_text:
            prompt_parts.append(symptoms_text)
        # <start_of_image> is the Gemma 3-family placeholder that tells the
        # processor where to splice in the image features.
        if symptom_image is not None:
            prompt_parts.append("<start_of_image>")
        prompt_parts.append(instruction)
        # Close the user turn before opening the model turn (Gemma chat format)
        prompt_parts.append("<end_of_turn>")
        prompt_parts.append("<start_of_turn>model")
        prompt = "\n".join(prompt_parts)
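
        # For illustration, with both inputs present the assembled prompt has
        # the Gemma chat shape (contents abridged):
        #
        #   <start_of_turn>user
        #   I have a rash on my arm...
        #   <start_of_image>
        #   You are an expert, empathetic AI medical assistant. ...
        #   <end_of_turn>
        #   <start_of_turn>model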
print("Generating pipeline output...")
# --- CORRECTED PIPELINE CALL ---
# The pipeline expects the prompt string and an 'images' argument (if an image is provided).
# We create a dictionary for keyword arguments to pass to the pipeline.
pipeline_kwargs = {
"max_new_tokens": 512,
"do_sample": True,
"temperature": 0.7
}
# The `images` argument should be a list of PIL Images.
if symptom_image:
output = pipe(prompt, images=[symptom_image], **pipeline_kwargs)
else:
# If no image is provided, we do not include the `images` argument in the call.
output = pipe(prompt, **pipeline_kwargs)
print("Pipeline Output:", output)
        # --- OUTPUT PROCESSING ---
        # The pipeline returns a list with one dict; the result is under the
        # 'generated_text' key.
        if output and isinstance(output, list) and 'generated_text' in output[0]:
            full_text = output[0]['generated_text']
            # The generated text echoes the prompt, so keep only what follows
            # the model-turn marker.
            result = full_text.split("<start_of_turn>model\n")[-1]
        else:
            result = "The model did not return a valid response. Please try again."

        disclaimer = (
            "\n\n***Disclaimer: I am an AI assistant and not a medical professional. "
            "This is not a diagnosis. Please consult a doctor for any health concerns.***"
        )
        return result + disclaimer
    except Exception as e:
        print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
        # Surface a user-friendly error message
        return f"An error occurred during analysis. Please check the logs for details: {str(e)}"
# --- Create the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
    gr.HTML("""
    <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
        <h1>🩺 AI Symptom Analyzer</h1>
        <p>Advanced symptom analysis powered by Google's MedGemma AI</p>
    </div>
    """)
    gr.HTML("""
    <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
        <strong>⚠️ Medical Disclaimer:</strong> This AI tool is for informational purposes only and is not a substitute for professional medical diagnosis or treatment.
    </div>
    """)
    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("### 1. Describe Your Symptoms")
            symptoms_input = gr.Textbox(
                label="Symptoms",
                placeholder="e.g., 'I have a rash on my arm that is red and itchy...'",
                lines=5,
            )
            gr.Markdown("### 2. Upload an Image (Optional)")
            image_input = gr.Image(label="Symptom Image", type="pil", height=300)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
                analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Analysis Report")
            output_text = gr.Textbox(
                label="AI Analysis",
                lines=25,
                show_copy_button=True,
                placeholder="Analysis results will appear here...",
            )
    def clear_all():
        # Return values map positionally onto [image_input, symptoms_input, output_text]
        return None, "", ""

    analyze_btn.click(fn=analyze_symptoms, inputs=[image_input, symptoms_input], outputs=output_text)
    clear_btn.click(fn=clear_all, outputs=[image_input, symptoms_input, output_text])
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)