import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces

# --- Initialize the Model Pipeline ---
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")


# --- Core Analysis Function ---
@spaces.GPU()
def analyze_symptoms(symptom_image, symptoms_text):
    """Analyze the user's symptoms from text and/or an image using the MedGemma pipeline."""
    if not model_loaded:
        return "Error: The AI model could not be loaded. Please check the Space logs."

    # Standardize input to avoid issues with None or whitespace.
    symptoms_text = symptoms_text.strip() if symptoms_text else ""
    if symptom_image is None and not symptoms_text:
        return "Please describe your symptoms or upload an image for analysis."

    try:
        # Build the system prompt dynamically from its parts.
        prompt_parts = [
            "You are an expert, empathetic AI medical assistant. Analyze the potential medical condition based on the following information.",
            "Provide a list of possible conditions, your reasoning, and a clear, actionable next-steps plan.",
            "Start your analysis by describing the user-provided information (text and/or image).",
        ]

        # The user's own text and image are passed directly to the model
        # instead of being wrapped in another instruction.
        user_input_for_model = []
        if symptoms_text:
            user_input_for_model.append({"type": "text", "text": symptoms_text})
        if symptom_image is not None:
            # The pipeline accepts a PIL Image object directly.
            user_input_for_model.append({"type": "image", "image": symptom_image})

        # The system prompt sets the context and instructions for the AI.
        system_prompt = " ".join(prompt_parts)

        messages = [
            {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
            {"role": "user", "content": user_input_for_model},
        ]

        print("Generating pipeline output...")
        output = pipe(messages, max_new_tokens=512, do_sample=True, temperature=0.7)

        # The output is a list containing the full conversation history;
        # the last message in that list is the AI's response.
        print("Pipeline Output:", output)

        # Make the output processing robust to different return formats.
        generated = output[0]["generated_text"]
        if isinstance(generated, list) and generated:
            # A list of message dicts: take the content of the last one.
            result = generated[-1].get("content", str(generated))
        elif isinstance(generated, str):
            # A plain string.
            result = generated
        else:
            # Failsafe for any other unexpected format.
            result = str(generated)

        disclaimer = (
            "\n\n***Disclaimer: I am an AI assistant and not a medical professional. "
            "This is not a diagnosis. Please consult a doctor for any health concerns.***"
        )
        return result + disclaimer

    except Exception as e:
        print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
        return f"Error during analysis: {str(e)}"


# --- Create the Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
    gr.HTML("""

        <div style="text-align: center;">
            <h1>🩺 AI Symptom Analyzer</h1>
            <p>Advanced symptom analysis powered by Google's MedGemma AI</p>
        </div>
    """)

    gr.HTML("""
        <p><strong>⚠️ Medical Disclaimer:</strong> This AI tool is for informational purposes only and is not a substitute for professional medical diagnosis or treatment.</p>
    """)

    with gr.Row(equal_height=True):
        with gr.Column(scale=1):
            gr.Markdown("### 1. Describe Your Symptoms")
            symptoms_input = gr.Textbox(
                label="Symptoms",
                placeholder="e.g., 'I have a rash on my arm that is red and itchy...'",
                lines=5,
            )
            gr.Markdown("### 2. Upload an Image (Optional)")
            image_input = gr.Image(label="Symptom Image", type="pil", height=300)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
                analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Analysis Report")
            output_text = gr.Textbox(
                label="AI Analysis",
                lines=25,
                show_copy_button=True,
                placeholder="Analysis results will appear here...",
            )

    def clear_all():
        return None, "", ""

    analyze_btn.click(
        fn=analyze_symptoms,
        inputs=[image_input, symptoms_input],
        outputs=output_text,
    )
    clear_btn.click(
        fn=clear_all,
        outputs=[image_input, symptoms_input, output_text],
    )

if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)