import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import os
import spaces # <-- FIX 1: IMPORT SPACES
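# Note: the spaces package supplies the @spaces.GPU decorator used below; on ZeroGPU
# Spaces it requests GPU time for the duration of each decorated function call.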
# Get the Hugging Face token from the environment variables
# Make sure to set this as a "Secret" in your Hugging Face Space settings
hf_token = os.environ.get("HF_TOKEN")
# Initialize the processor and model
# We are using MedGemma, a 4B parameter model specialized for medical text and images.
model_id = "google/medgemma-4b-it"
# Check for GPU availability and set the data type accordingly
# Using bfloat16 for better performance on compatible GPUs.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    # Fallback to float16 if bfloat16 is not available
    dtype = torch.float16

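# Flag checked inside the chat function so the UI can report a failed model load.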
model_loaded = False
# Load the processor and model from Hugging Face
try:
    # AutoProcessor handles both text tokenization and image processing
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    # We will display an error in the UI if the model fails to load.

# This is the core function for the chatbot
@spaces.GPU # <-- FIX 1: ADD THE GPU DECORATOR
def symptom_checker_chat(user_input, history, image_input):
"""
Manages the conversational flow for the symptom checker.
"""
if not model_loaded:
history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
# <-- FIX 3 & 4: Return values match new outputs
return history, history, None, ""
# System prompt to guide the model's behavior
system_prompt = """
You are an expert, empathetic AI medical assistant. Your role is to analyze a user's symptoms and provide a helpful, safe, and informative response.
Here is your workflow:
1. Analyze the user's initial input, which may include text and an image.
2. If the information is insufficient, ask specific, relevant clarifying questions to better understand the symptoms (e.g., "How long have you had this symptom?", "Can you describe the pain? Is it sharp or dull?").
3. Once you have gathered enough information, provide a list of possible conditions that might align with the symptoms.
4. For each possible condition, briefly explain why it might be relevant.
5. Provide a clear, actionable plan, such as "It would be best to monitor your symptoms," or "You should consider consulting a healthcare professional."
6. **Crucially, you must ALWAYS end every single response with the following disclaimer, formatted exactly like this, on a new line:**
***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
"""
# Construct the conversation history for the model
conversation = [{"role": "system", "content": system_prompt}]
for user, assistant in history:
conversation.append({"role": "user", "content": user})
if assistant: # Ensure assistant message is not None
conversation.append({"role": "assistant", "content": assistant})
# Add the current user input with a special image token if an image is present
if image_input:
# MedGemma expects the text to start with <image> token if an image is provided
conversation.append({"role": "user", "content": f"<image>\n{user_input}"})
else:
conversation.append({"role": "user", "content": user_input})
# Apply the chat template
prompt = processor.tokenizer.apply_chat_template(
conversation,
tokenize=False,
add_generation_prompt=True
)
inputs = None
# Process inputs, including the image if it exists
if image_input:
inputs = processor(text=prompt, images=image_input, return_tensors="pt").to(model.device, dtype)
else:
inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
# Generate the output from the model
try:
outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
# <-- FIX 2: ROBUST RESPONSE PARSING
# Decode only the newly generated tokens, not the whole conversation
input_token_len = inputs["input_ids"].shape[1]
generated_tokens = outputs[:, input_token_len:]
clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
except Exception as e:
print(f"Error during model generation: {e}")
clean_response = "An error occurred while generating the response. Please check the logs."
# Update the history
history.append((user_input, clean_response))
# <-- FIX 3 & 4: Return values to update state, clear image box, and clear text box
return history, history, None, ""
# Create the Gradio Interface using Blocks for more control
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash). The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
        """
    )

    # Chatbot component to display the conversation
    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))  # Added avatars for fun

    # State to store the conversation history
    chat_history = gr.State([])  # <-- FIX 3: This state will now be used correctly

    with gr.Row():
        # Image input
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    with gr.Row():
        # Text input
        text_box = gr.Textbox(
            label="Describe your symptoms...",
            placeholder="e.g., I have a rash on my arm. It's red and itchy.",
            scale=4,  # Make the textbox wider
        )
        # Submit button
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Function to clear all inputs
    def clear_all():
        return [], [], None, ""  # <-- FIX 3: Correctly clear the state and chatbot

    # Clear button
    clear_btn = gr.Button("Start New Conversation")

    # <-- FIX 3: The outputs list now correctly targets the state
    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)

    # Define what happens when the user submits
    submit_btn.click(
        fn=symptom_checker_chat,
        # <-- FIX 3 & 4: Corrected inputs and outputs
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )

    # Define what happens when the user just presses Enter in the textbox
    text_box.submit(
        fn=symptom_checker_chat,
        # <-- FIX 3 & 4: Corrected inputs and outputs
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True)  # Debug mode for more detailed logs