File size: 5,675 Bytes
83ff66a
77793f4
b67fca4
77793f4
83ff66a
b0565c1
83ff66a
a91bbfc
998c789
83ff66a
77793f4
a91bbfc
77793f4
998c789
6ef5bdf
77793f4
83ff66a
 
77793f4
83ff66a
998c789
b67fca4
83ff66a
a91bbfc
998c789
a91bbfc
b67fca4
33d4002
 
b67fca4
 
998c789
bc69d2f
 
 
998c789
77793f4
 
33d4002
a91bbfc
60665db
a91bbfc
60665db
a91bbfc
 
60665db
a91bbfc
 
 
 
 
998c789
a91bbfc
 
 
 
 
 
1fa102b
33d4002
 
 
 
 
 
 
 
 
 
 
 
 
a91bbfc
ea22a67
3d9624f
33d4002
a91bbfc
 
 
 
 
 
 
ea22a67
60665db
998c789
 
 
 
77793f4
 
3d9624f
60665db
998c789
33d4002
998c789
 
 
 
 
 
 
 
 
 
 
 
9c4076b
998c789
 
 
 
 
ea22a67
998c789
ea22a67
998c789
 
 
 
 
 
 
ea22a67
998c789
9c4076b
998c789
9c4076b
ea22a67
 
9c4076b
 
998c789
33d4002
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces

# --- Initialize the Model Pipeline ---
# MedGemma (google/medgemma-4b-it) is a multimodal chat model, so it is served
# through the "image-text-to-text" pipeline, which accepts chat-format messages
# mixing text and images. The "image-to-text" task is for image captioning and
# does not take the `messages` list built in analyze_symptoms.
print("Loading MedGemma model...")
try:
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        torch_dtype=torch.bfloat16,
        device_map="auto",
        # Read from the Space secrets; presumably required because the model
        # repository is gated — confirm against the Space configuration.
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    # Best effort: keep the UI running and let analyze_symptoms report the
    # failure via `model_loaded` instead of crashing the whole app at import.
    pipe = None
    model_loaded = False
    print(f"Error loading model: {e}")

# --- Core Analysis Function ---
_SYSTEM_INSTRUCTION = (
    "You are an expert, empathetic AI medical assistant. "
    "Analyze the potential medical condition based on the user's input. "
    "Provide a list of possible conditions, your reasoning, and a clear, "
    "actionable next-steps plan. Begin your analysis by describing the information "
    "the user provided."
)

_DISCLAIMER = "\n\n***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***"


def _build_messages(symptom_image, symptoms_text):
    """Build the chat-format message list sent to the MedGemma pipeline.

    Args:
        symptom_image: Image to attach to the user turn, or ``None``.
        symptoms_text: Already-stripped symptom description; may be empty.

    Returns:
        A two-turn conversation (system + user) whose contents are lists of
        typed parts, the format used in the MedGemma model card examples.
    """
    # With an image but no description, still give the model an instruction.
    text_to_send = symptoms_text or "Please analyze this medical image."
    user_content = [{"type": "text", "text": text_to_send}]
    if symptom_image is not None:
        user_content.append({"type": "image", "image": symptom_image})
    return [
        {"role": "system", "content": [{"type": "text", "text": _SYSTEM_INSTRUCTION}]},
        {"role": "user", "content": user_content},
    ]


def _extract_response(output):
    """Return the assistant's reply text from raw pipeline output, or None.

    For chat input, ``output[0]["generated_text"]`` is expected to be the full
    conversation (a list of message dicts) ending with the assistant turn. A
    plain string is also accepted. Returns ``None`` when no non-blank reply
    can be found.
    """
    if not isinstance(output, list) or not output or not isinstance(output[0], dict):
        return None
    generated = output[0].get("generated_text")
    if not generated or not isinstance(generated, (str, list)):
        return None

    if isinstance(generated, str):
        reply = generated
    else:
        last_turn = generated[-1]
        if isinstance(last_turn, dict) and last_turn.get("role") == "assistant":
            content = last_turn.get("content")
            if isinstance(content, list):
                # Typed content parts, e.g. [{"type": "text", "text": "..."}].
                reply = "".join(
                    part.get("text", "") for part in content if isinstance(part, dict)
                )
            else:
                reply = "" if content is None else str(content)
        else:
            # Unexpected final turn: surface it verbatim rather than failing.
            reply = str(last_turn)

    return reply if reply.strip() else None


# ZeroGPU: presumably attaches a GPU for the duration of each call — see the
# `spaces` package docs.
@spaces.GPU()
def analyze_symptoms(symptom_image: Image.Image, symptoms_text: str):
    """Analyze the user's symptom description and/or image with MedGemma.

    Args:
        symptom_image: PIL image of the symptom, or ``None`` if none was uploaded.
        symptoms_text: Free-text symptom description; may be ``None`` or blank.

    Returns:
        A user-facing string. This is either the model's analysis followed by a
        medical disclaimer, or a short guidance/error message. Exceptions raised
        while building the prompt or generating are caught, printed to the logs
        and reported in the returned string.
    """
    if not model_loaded:
        return "Error: The AI model could not be loaded. Please check the Space logs."

    symptoms_text = symptoms_text.strip() if symptoms_text else ""
    if symptom_image is None and not symptoms_text:
        return "Please describe your symptoms or upload an image for analysis."

    try:
        messages = _build_messages(symptom_image, symptoms_text)

        print("Generating pipeline output with chat format...")

        # Generation options go in `generate_kwargs`; built per call so the
        # pipeline can never mutate a shared dict across requests.
        generate_kwargs = {
            "max_new_tokens": 512,
            "do_sample": True,
            "temperature": 0.7,
        }
        # The image-text-to-text pipeline takes chat-format input via `text=`.
        output = pipe(text=messages, generate_kwargs=generate_kwargs)

        # NOTE(review): this logs the full conversation, including the user's
        # health information, to the Space logs — consider trimming.
        print("Pipeline Output:", output)

        result = _extract_response(output)
        if result is None:
            result = "The model did not return a valid response. Please try again."

        return result + _DISCLAIMER

    except Exception as e:
        print(f"An exception occurred during analysis: {type(e).__name__}: {e}")
        return f"An error occurred during analysis. Please check the logs for details: {str(e)}"

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), title="AI Symptom Analyzer") as demo:
    gr.HTML("""
        <div style="text-align: center; background: linear-gradient(135deg, #667eea 0%, #764ba2 100%); color: white; padding: 2rem; border-radius: 10px; margin-bottom: 2rem;">
            <h1>🩺 AI Symptom Analyzer</h1>
            <p>Advanced symptom analysis powered by Google's MedGemma AI</p>
        </div>
    """)
    gr.HTML("""
        <div style="background-color: #fff3cd; border: 1px solid #ffeaa7; border-radius: 8px; padding: 1rem; margin: 1rem 0; color: #856404;">
            <strong>⚠️ Medical Disclaimer:</strong> This AI tool is for informational purposes only and is not a substitute for professional medical diagnosis or treatment.
        </div>
    """)

    with gr.Row(equal_height=True):
        # Left column: user inputs and action buttons.
        with gr.Column(scale=1):
            gr.Markdown("### 1. Describe Your Symptoms")
            symptoms_input = gr.Textbox(
                label="Symptoms",
                placeholder="e.g., 'I have a rash on my arm that is red and itchy...'", lines=5)
            gr.Markdown("### 2. Upload an Image (Optional)")
            # type="pil" hands analyze_symptoms a PIL image (or None).
            image_input = gr.Image(label="Symptom Image", type="pil", height=300)
            with gr.Row():
                clear_btn = gr.Button("🗑️ Clear All", variant="secondary")
                analyze_btn = gr.Button("🔍 Analyze Symptoms", variant="primary", size="lg")

        # Right column: the generated report.
        with gr.Column(scale=1):
            gr.Markdown("### 📊 Analysis Report")
            output_text = gr.Textbox(
                label="AI Analysis", lines=25, show_copy_button=True, placeholder="Analysis results will appear here...")

    def clear_all():
        """Reset the image, symptom text and report, in `outputs` order."""
        return None, "", ""

    analyze_btn.click(fn=analyze_symptoms, inputs=[image_input, symptoms_input], outputs=output_text)
    clear_btn.click(fn=clear_all, outputs=[image_input, symptoms_input, output_text])

if __name__ == "__main__":
    print("Starting Gradio interface...")
    demo.launch(debug=True)