File size: 4,857 Bytes
83ff66a
c5882f3
b67fca4
77793f4
83ff66a
b0565c1
83ff66a
50aaa9b
c5882f3
83ff66a
c5882f3
 
 
998c789
6ef5bdf
77793f4
83ff66a
 
c5882f3
83ff66a
998c789
b67fca4
83ff66a
50aaa9b
998c789
a91bbfc
b67fca4
50aaa9b
b67fca4
 
998c789
bc69d2f
 
 
998c789
77793f4
 
c5882f3
 
 
 
 
60665db
d305e52
c5882f3
 
9f24600
c5882f3
 
 
4334aa5
c5882f3
 
4334aa5
c5882f3
50aaa9b
c5882f3
50aaa9b
33d4002
d305e52
33d4002
82620d4
c5882f3
 
82620d4
50aaa9b
c5882f3
 
998c789
 
2588693
77793f4
 
3d9624f
60665db
998c789
50aaa9b
998c789
 
9f24600
998c789
 
 
 
 
 
 
 
 
9c4076b
998c789
 
 
 
 
ea22a67
998c789
ea22a67
998c789
 
 
 
 
 
 
ea22a67
998c789
50aaa9b
ea22a67
c5882f3
 
9c4076b
 
998c789
82620d4
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces

# --- Initialize the Model Pipeline ---
# Load MedGemma once at import time. Loading failures are logged but not
# raised, so the UI still starts; `model_loaded` records whether `pipe` is
# usable.
print("Loading MedGemma model...")
# Bind both names up front so they exist even when loading fails.
pipe = None
model_loaded = False
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # Access token comes from the HF_TOKEN environment variable (None if unset).
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    # Deliberate best-effort: report the failure and keep the app running.
    print(f"Error loading model: {type(e).__name__}: {e}")

# --- Core Analysis Function ---
def _extract_reply(output) -> str:
    """Return the assistant's reply text from an image-text-to-text pipeline result.

    When the pipeline receives a chat message list, ``generated_text`` is that
    conversation with the assistant turn appended, so the reply is the
    ``content`` of the last message. If ``generated_text`` is already a plain
    string (the pipeline may return only the new text in some configurations),
    it is returned unchanged.
    """
    generated = output[0]["generated_text"]
    if isinstance(generated, str):
        return generated
    return generated[-1]["content"]


@spaces.GPU()
def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
    """
    Analyze a symptom description and/or image with the MedGemma pipeline.

    Builds a chat-style request (a system prompt plus one user turn holding
    whichever of text and image was supplied), runs the module-level ``pipe``,
    and appends a medical disclaimer to the model's reply.

    Args:
        symptom_image: PIL image from the upload component, or ``None`` when
            no image was provided.
        symptoms_text: Free-text symptom description; may be ``None`` or blank.

    Returns:
        A string for the output textbox. This is one of:
        - the model's analysis followed by a disclaimer;
        - a prompt asking for input when neither text nor image was given;
        - an error message, when the model failed to load or generation
          raised. Exceptions are logged, not propagated.
    """
    if not model_loaded:
        return "Error: The AI model could not be loaded. Please check the Space logs."

    symptoms_text = symptoms_text.strip() if symptoms_text else ""
    if symptom_image is None and not symptoms_text:
        return "Please describe your symptoms or upload an image for analysis."

    system_prompt = (
        "You are an expert, empathetic AI medical assistant. Analyze the potential "
        "medical condition based on the following information. Provide a list of "
        "possible conditions, your reasoning, and a clear, actionable next-steps plan. "
        "Start your analysis by describing the user-provided information."
    )
    disclaimer = "\n\n***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***"

    # Include only the parts the user actually supplied, so an image-only
    # request does not carry an empty text part into the chat template.
    user_content = []
    if symptoms_text:
        user_content.append({"type": "text", "text": symptoms_text})
    if symptom_image is not None:
        user_content.append({"type": "image", "image": symptom_image})

    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": user_content},
    ]

    generation_args = {
        "max_new_tokens": 1024,  # generous limit so multi-section reports are not truncated
        "do_sample": True,
        "temperature": 0.7,
    }

    try:
        # Chat-formatted input: the whole message list is passed via `text`.
        output = pipe(text=messages, **generation_args)
        result = _extract_reply(output).strip()
    except Exception as e:
        import traceback  # only needed on this error path

        print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
        traceback.print_exc()  # full stack trace goes to the Space logs
        return f"An error occurred during analysis. Please check the logs for details: {str(e)}"

    return result + disclaimer

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
    # Page header banner.
    gr.HTML("""
        <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
            <h1>🩺 AI Symptom Analyzer</h1>
            <p>Advanced symptom analysis powered by Google's MedGemma AI</p>
        </div>
    """)
    # Always-visible medical disclaimer.
    gr.HTML("""
        <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
            <strong>⚠️ Medical Disclaimer:</strong> This AI tool is for informational purposes only and is not a substitute for professional medical diagnosis or treatment.
        </div>
    """)

    with gr.Row(equal_height=True):
        # Left column: user inputs and action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Describe Your Symptoms")
            symptoms_input = gr.Textbox(
                label="Symptoms",
                placeholder="e.g., 'I have a rash on my arm that is red and itchy...'", lines=5)
            gr.Markdown("### 2. Upload an Image (Optional)")
            # type="pil" hands analyze_symptoms a PIL image (or None when empty).
            image_input = gr.Image(label="Symptom Image", type="pil", height=300)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
                analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")

        # Right column: the model's report.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Analysis Report")
            output_text = gr.Textbox(
                label="AI Analysis", lines=25, show_copy_button=True, placeholder="Analysis results will appear here...")

    # Event handlers. Inputs are passed positionally, so the order
    # [image, text] must match analyze_symptoms(symptom_image, symptoms_text).
    analyze_btn.click(fn=analyze_symptoms, inputs=[image_input, symptoms_input], outputs=output_text)
    # Reset the image, the symptom text and the report in a single update.
    clear_btn.click(fn=lambda: (None, "", ""), outputs=[image_input, symptoms_input, output_text])


def _launch_app():
    """Announce startup on stdout, then start the Gradio server with debug=True."""
    print("Starting Gradio interface...")
    demo.launch(debug=True)


if __name__ == "__main__":
    _launch_app()