import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import os
import spaces  # <-- FIX 1: IMPORT SPACES
# Get the Hugging Face token from the environment variables
# Make sure to set this as a "Secret" in your Hugging Face Space settings
hf_token = os.environ.get("HF_TOKEN")
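# Note (assumption worth flagging): MedGemma is a gated model on the Hub, so if HF_TOKEN is
# missing or the account has not accepted the model's terms, the download below will fail
# with an authorization error rather than loading anonymously.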
# Initialize the processor and model
# We are using MedGemma, a 4B parameter model specialized for medical text and images.
model_id = "google/medgemma-4b-it"

# Check for GPU availability and set the data type accordingly
# Using bfloat16 for better performance on compatible GPUs.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    # Fall back to float16 if bfloat16 is not available
    dtype = torch.float16
model_loaded = False

# Load the processor and model from Hugging Face
try:
    # AutoProcessor handles both text tokenization and image processing
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    # MedGemma 4B is a multimodal (Gemma 3 based) model, so load it with the
    # image-text-to-text class rather than a text-only causal LM so image inputs work.
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    # We will display an error in the UI if the model fails to load.
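# Note: device_map="auto" relies on the `accelerate` package being installed in the Space;
# without it, from_pretrained raises an error before the model ever loads.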
# This is the core function for the chatbot
# <-- FIX 1: ADD THE GPU DECORATOR
@spaces.GPU
def symptom_checker_chat(user_input, history, image_input):
    """
    Manages the conversational flow for the symptom checker.
    """
    if not model_loaded:
        history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
        # <-- FIX 3 & 4: Return values match new outputs
        return history, history, None, ""
    # System prompt to guide the model's behavior
    system_prompt = """
You are an expert, empathetic AI medical assistant. Your role is to analyze a user's symptoms and provide a helpful, safe, and informative response.
Here is your workflow:
1. Analyze the user's initial input, which may include text and an image.
2. If the information is insufficient, ask specific, relevant clarifying questions to better understand the symptoms (e.g., "How long have you had this symptom?", "Can you describe the pain? Is it sharp or dull?").
3. Once you have gathered enough information, provide a list of possible conditions that might align with the symptoms.
4. For each possible condition, briefly explain why it might be relevant.
5. Provide a clear, actionable plan, such as "It would be best to monitor your symptoms," or "You should consider consulting a healthcare professional."
6. **Crucially, you must ALWAYS end every single response with the following disclaimer, formatted exactly like this, on a new line:**
***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
"""
    # Construct the conversation history for the model
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in history:
        conversation.append({"role": "user", "content": user})
        if assistant:  # Ensure assistant message is not None
            conversation.append({"role": "assistant", "content": assistant})

    # Add the current user input, with an image placeholder token if an image is present
    if image_input:
        # Gemma 3 based models such as MedGemma mark the image position in the text with the
        # <start_of_image> token; the processor expands it into the model's image tokens.
        conversation.append({"role": "user", "content": f"<start_of_image>\n{user_input}"})
    else:
        conversation.append({"role": "user", "content": user_input})
    # Apply the chat template
    prompt = processor.tokenizer.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )

    # Process inputs, including the image if it exists
    if image_input:
        inputs = processor(text=prompt, images=image_input, return_tensors="pt").to(model.device, dtype)
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
    # Generate the output from the model
    try:
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        # <-- FIX 2: ROBUST RESPONSE PARSING
        # Decode only the newly generated tokens, not the whole conversation
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error during model generation: {e}")
        clean_response = "An error occurred while generating the response. Please check the logs."

    # Update the history
    history.append((user_input, clean_response))
    # <-- FIX 3 & 4: Return values to update state, clear image box, and clear text box
    return history, history, None, ""
# Create the Gradio interface using Blocks for more control
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
# AI Symptom Checker powered by MedGemma
Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash). The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
        """
    )
    # Chatbot component to display the conversation
    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))  # Added avatars for fun
    # State to store the conversation history
    chat_history = gr.State([])  # <-- FIX 3: This state will now be used correctly
    with gr.Row():
        # Image input
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    with gr.Row():
        # Text input
        text_box = gr.Textbox(
            label="Describe your symptoms...",
            placeholder="e.g., I have a rash on my arm. It's red and itchy.",
            scale=4,  # Make the textbox wider
        )
        # Submit button
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Function to clear all inputs
    def clear_all():
        return [], [], None, ""  # <-- FIX 3: Correctly clear the state and chatbot

    # Clear button
    clear_btn = gr.Button("Start New Conversation")
    # <-- FIX 3: The outputs list now correctly targets the state
    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)
    # Define what happens when the user clicks the Send button
    submit_btn.click(
        fn=symptom_checker_chat,
        # <-- FIX 3 & 4: Corrected inputs and outputs
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )

    # Define what happens when the user just presses Enter in the textbox
    text_box.submit(
        fn=symptom_checker_chat,
        # <-- FIX 3 & 4: Corrected inputs and outputs
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True)  # Debug mode for more detailed logs
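# A minimal requirements.txt for this Space might look like the sketch below. The package set
# is an assumption inferred from the imports above (versions left unpinned on purpose):
#   gradio
#   torch
#   transformers
#   accelerate
#   pillow
#   spaces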