import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import os
import spaces  # needed for the @spaces.GPU decorator below

# Get the Hugging Face token from the environment variables
# Make sure to set this as a "Secret" in your Hugging Face Space settings
hf_token = os.environ.get("HF_TOKEN")

# Initialize the processor and model
# We are using MedGemma, a 4B parameter model specialized for medical text and images.
model_id = "google/medgemma-4b-it"

# Check for GPU availability and set the data type accordingly
# Using bfloat16 for better performance on compatible GPUs.
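# (bfloat16 requires a GPU with compute capability 8.0 or higher, e.g. Ampere-class cards.)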
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    # Fallback to float16 if bfloat16 is not available
    dtype = torch.float16

model_loaded = False
# Load the processor and model from Hugging Face
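# NOTE: device_map="auto" requires the `accelerate` package to be installed.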
try:
    # AutoProcessor handles both text tokenization and image processing
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    # MedGemma 4B is a multimodal (image + text) model, so it is loaded with the
    # image-text-to-text class rather than a text-only causal LM class.
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    # We will display an error in the UI if the model fails to load.

# This is the core function for the chatbot
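# @spaces.GPU asks a ZeroGPU Space to attach a GPU for the duration of each call.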
@spaces.GPU
def symptom_checker_chat(user_input, history, image_input):
    """
    Manages the conversational flow for the symptom checker.
    """
    if not model_loaded:
        history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
        return history, history, None, ""

    # NOTE: the full system prompt is elided in this snippet; only its opening
    # line and the mandatory disclaimer are reproduced here.
    system_prompt = """
    You are an expert, empathetic AI medical assistant...
    ***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
    """
    
    # Create the text input for the model. When an image is attached, prepend
    # Gemma 3's image placeholder token so the processor knows where to splice
    # in the image embeddings. This is the text the model actually sees.
    if image_input:
        model_input_text = f"<start_of_image>\n{user_input}"
    else:
        model_input_text = user_input

    # Construct the conversation history for the model
    # We use the full 'model_input_text' from previous turns if available.
    conversation = [{"role": "system", "content": system_prompt}]
    for turn_input, assistant_output in history:
        # The history now stores the text that was ACTUALLY sent to the model
        conversation.append({"role": "user", "content": turn_input})
        if assistant_output:
            conversation.append({"role": "assistant", "content": assistant_output})
            
    # Add the current user turn
    conversation.append({"role": "user", "content": model_input_text})
        
    # Use the processor (not the bare tokenizer) to apply the chat template;
    # it correctly handles multimodal (text + image) conversations.
    prompt = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )

    # Process inputs, including the image if it exists
    if image_input:
        inputs = processor(text=prompt, images=image_input, return_tensors="pt").to(model.device, dtype)
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)

    # Generate the output from the model
    try:
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
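        # Decode only the newly generated tokens, not the prompt we fed in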
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()

    except Exception as e:
        print(f"Error during model generation: {e}")
        clean_response = "An error occurred while generating the response. Please check the logs."

    # Store the model-facing text (including the image placeholder token) in the
    # history state so multimodal context is preserved on later turns...
    history.append((model_input_text, clean_response))
    
    # ... but the chatbot UI should display the clean user input.
    # We construct the display history here.
    display_history = []
    for turn_input, assistant_output in history:
        # Strip the image placeholder token for display purposes
        display_input = turn_input.replace("<start_of_image>\n", "")
        display_history.append((display_input, assistant_output))
    
    # Return values: update chatbot UI, update history state, clear image, clear textbox
    return display_history, history, None, ""

# --- Gradio UI ---
# NOTE: the original snippet elides most of the gr.Blocks layout ("... your
# gr.Blocks setup is the same ..."). The layout below is a minimal
# reconstruction, using the component names that the event handlers reference.
with gr.Blocks(title="MedGemma Symptom Checker") as demo:
    gr.Markdown("# MedGemma Symptom Checker")

    # The Chatbot shows the cleaned display history; the State keeps the raw
    # history WITH the image placeholder tokens so the model retains context.
    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))
    chat_history = gr.State([])

    # Assumed input widgets (the originals were elided from this snippet).
    image_box = gr.Image(type="pil", label="Upload a medical image (optional)")
    text_box = gr.Textbox(label="Describe your symptoms", placeholder="e.g. I've had a headache and mild fever for three days...")
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")

    # Clears the chatbot display, the history state, the image, and the textbox
    def clear_all():
        return [], [], None, ""

    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)

    # On submit: chatbot gets the display history, chat_history gets the
    # model-facing history, and the image and textbox inputs are cleared.
    submit_btn.click(
        fn=symptom_checker_chat,
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box],
    )

    # Pressing Enter in the textbox triggers the same handler
    text_box.submit(
        fn=symptom_checker_chat,
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box],
    )


# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True) # Debug mode for more detailed logs