# Spaces status: Sleeping (page residue captured with the source; not part of the program)
import os

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import pipeline
# --- Model pipeline bootstrap ---
# Runs once at import time. On success `pipe` holds the ready pipeline and
# `model_loaded` is True; on any failure `model_loaded` is False and `pipe`
# is left unbound, so callers must check the flag before using `pipe`.
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        task="image-text-to-text",
        model="google/medgemma-4b-it",
        # Access token comes from the environment; None when HF_TOKEN is unset.
        token=os.getenv("HF_TOKEN"),
        device_map="auto",
        torch_dtype=torch.bfloat16,
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as load_error:
    # Keep the app importable so the UI can still start and report the problem.
    model_loaded = False
    print(f"Error loading model: {load_error}")
# --- Core Analysis Function ---
def _build_chat_messages(symptom_image, symptoms_text):
    """Assemble the single-turn chat payload (text, optional image, instruction)."""
    instruction = (
        "You are an expert, empathetic AI medical assistant. "
        "Analyze the potential medical condition based on the following information. "
        "Provide a list of possible conditions, your reasoning, and a clear, "
        "actionable next-steps plan. Start your analysis by describing the user-provided "
        "information (text and/or image)."
    )
    content = []
    if symptoms_text:
        content.append({"type": "text", "text": symptoms_text})
    if symptom_image is not None:
        content.append({"type": "image", "image": symptom_image})
    content.append({"type": "text", "text": instruction})
    return [{"role": "user", "content": content}]


def _extract_reply(output):
    """Return the assistant's text from a pipeline result, or None if it has none."""
    if not (isinstance(output, list) and output and isinstance(output[0], dict)):
        return None
    generated = output[0].get("generated_text")
    if isinstance(generated, str):
        # Plain-string form (e.g. when only the newly generated text is returned).
        return generated.strip()
    if isinstance(generated, list) and generated and isinstance(generated[-1], dict):
        # Chat form: a list of messages whose last entry is the model's turn.
        content = generated[-1].get("content")
        if isinstance(content, str):
            return content.strip()
        if isinstance(content, list):
            parts = [
                part.get("text", "")
                for part in content
                if isinstance(part, dict) and part.get("type") == "text"
            ]
            return "\n".join(parts).strip()
    return None


def analyze_symptoms(symptom_image, symptoms_text):
    """Analyze a symptom description and/or photo with the MedGemma pipeline.

    Args:
        symptom_image: PIL image from the upload widget, or None when no image
            was supplied.
        symptoms_text: Free-text symptom description. None or whitespace-only
            input is treated as "no text".

    Returns:
        str: The model's analysis followed by a medical disclaimer, or a plain
        message asking for input / reporting a failure. Exceptions raised while
        generating are caught and reported as text rather than propagated.
    """
    if not model_loaded:
        return "Error: The AI model could not be loaded. Please check the Space logs."

    # Standardize input to avoid issues with None or whitespace
    symptoms_text = symptoms_text.strip() if symptoms_text else ""
    if symptom_image is None and not symptoms_text:
        return "Please describe your symptoms or upload an image for analysis."

    try:
        # Send structured chat messages instead of a hand-written prompt string:
        # the processor then applies the model's own chat template and image
        # placeholder, and the image travels inside the message content.
        messages = _build_chat_messages(symptom_image, symptoms_text)

        print("Generating pipeline output...")
        # `text=` must be passed by keyword: the pipeline's first positional
        # parameter is `images`, so a positional prompt would be misread.
        output = pipe(
            text=messages,
            max_new_tokens=512,
            do_sample=True,
            temperature=0.7,
        )
        print("Pipeline Output:", output)

        result = _extract_reply(output)
        if result is None:
            result = "The model did not return a valid response. Please try again."

        disclaimer = "\n\n***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***"
        return result + disclaimer
    except Exception as e:
        # UI boundary: log the failure and hand the user a readable message.
        print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
        return f"An error occurred during analysis. Please check the logs for details: {str(e)}"
# --- Create the Gradio Interface ---
# Two-column layout: inputs (text + optional image + buttons) on the left,
# the analysis report on the right.
# NOTE(review): the emoji in this block were mis-encoded in the captured source.
# The stethoscope and warning sign are recoverable from the surviving bytes; the
# "Clear All", "Analyze Symptoms" and "Analysis Report" icons were chosen from
# context — confirm against the intended design.
with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
    # Page banner.
    gr.HTML("""
    <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
        <h1>🩺 AI Symptom Analyzer</h1>
        <p>Advanced symptom analysis powered by Google's MedGemma AI</p>
    </div>
    """)
    # Always-visible medical disclaimer.
    gr.HTML("""
    <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
        <strong>⚠️ Medical Disclaimer:</strong> This AI tool is for informational purposes only and is not a substitute for professional medical diagnosis or treatment.
    </div>
    """)

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("### 1. Describe Your Symptoms")
            symptoms_input = gr.Textbox(
                label="Symptoms",
                placeholder="e.g., 'I have a rash on my arm that is red and itchy...'",
                lines=5,
            )
            gr.Markdown("### 2. Upload an Image (Optional)")
            # type="pil" makes the handler receive a PIL image (or None).
            image_input = gr.Image(label="Symptom Image", type="pil", height=300)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
                analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")

        with gr.Column(scale=1):
            gr.Markdown("### 📋 Analysis Report")
            output_text = gr.Textbox(
                label="AI Analysis",
                lines=25,
                show_copy_button=True,
                placeholder="Analysis results will appear here...",
            )

    def clear_all():
        """Reset the image, the symptom text and the report, in that order."""
        return None, "", ""

    # Event wiring; the order of inputs/outputs matches the handler signatures.
    analyze_btn.click(fn=analyze_symptoms, inputs=[image_input, symptoms_input], outputs=output_text)
    clear_btn.click(fn=clear_all, outputs=[image_input, symptoms_input, output_text])

# Start the app only when the file is executed as a script.
if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)