import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import os
import spaces  # <-- FIX 1: IMPORT SPACES
# Get the Hugging Face token from the environment variables.
# Make sure to set this as a "Secret" in your Hugging Face Space settings.
hf_token = os.environ.get("HF_TOKEN")

# Initialize the processor and model.
# We are using MedGemma, a 4B parameter model specialized for medical text and images.
model_id = "google/medgemma-4b-it"

# Check for GPU availability and set the data type accordingly.
# Using bfloat16 for better performance on compatible GPUs.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    # Fall back to float16 if bfloat16 is not available
    dtype = torch.float16
model_loaded = False

# Load the processor and model from Hugging Face
try:
    # AutoProcessor handles both text tokenization and image processing
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    # We will display an error in the UI if the model fails to load.

# This is the core function for the chatbot
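# Assumption: this Space runs on ZeroGPU hardware. The 'spaces' import above
# only has an effect if the GPU-bound function is decorated, so a GPU is
# requested per call here; remove the decorator if you use dedicated GPU hardware.
@spaces.GPU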
def symptom_checker_chat(user_input, history, image_input):
    """
    Manages the conversational flow for the symptom checker.
    """
    if not model_loaded:
        history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
        return history, history, None, ""

    system_prompt = """
    You are an expert, empathetic AI medical assistant... (rest of your prompt is fine)
    ***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
    """
    # Create the text input for the model. This will include the <image> token if needed.
    # This is what the model sees.
    if image_input:
        model_input_text = f"<image>\n{user_input}"
    else:
        model_input_text = user_input

    # Construct the conversation history for the model.
    # We use the full 'model_input_text' from previous turns if available.
    conversation = [{"role": "system", "content": system_prompt}]
    for turn_input, assistant_output in history:
        # The history stores the text that was ACTUALLY sent to the model
        conversation.append({"role": "user", "content": turn_input})
        if assistant_output:
            conversation.append({"role": "assistant", "content": assistant_output})

    # Add the current user turn
    conversation.append({"role": "user", "content": model_input_text})
    # --- FIX 1: Use the main processor for templating, not the tokenizer ---
    # This correctly handles multimodal (text + image) conversations.
    prompt = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )

    # Process inputs, including the image if it exists
    if image_input:
        inputs = processor(text=prompt, images=image_input, return_tensors="pt").to(model.device, dtype)
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
    # Generate the output from the model
    try:
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error during model generation: {e}")
        clean_response = "An error occurred while generating the response. Please check the logs."

    # --- FIX 2: Store the correct information in history ---
    # The history state should contain the text with the <image> token to maintain context...
    history.append((model_input_text, clean_response))

    # ...but the chatbot UI should display the clean user input.
    # We construct the display history here.
    display_history = []
    for turn_input, assistant_output in history:
        # Strip the <image> token for display purposes
        display_input = turn_input.replace("<image>\n", "")
        display_history.append((display_input, assistant_output))

    # Return values: update chatbot UI, update history state, clear image, clear textbox
    return display_history, history, None, ""

# --- Modify the Gradio part to match the new return signature ---
# ... (your gr.Blocks setup is the same) ...
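# The gr.Blocks setup was elided above; a minimal, assumed opener is sketched
# here so the components below live inside a Blocks context and the `demo`
# object used by demo.launch() exists. Swap in your actual layout and title.
with gr.Blocks() as demo: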
    # The Chatbot component and State
    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))
    chat_history = gr.State([])  # This will store the history WITH the <image> tokens

    # ... (image_box, text_box, buttons are the same) ...
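    # The image input, textbox, and buttons were elided in the original snippet.
    # The sketch below is an assumption that only reuses the names the event
    # handlers expect (image_box, text_box, submit_btn, clear_btn); keep your
    # own definitions if you already have them.
    image_box = gr.Image(type="pil", label="Upload a relevant image (optional)")
    text_box = gr.Textbox(label="Describe your symptoms")
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")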
    # Function to clear all inputs
    def clear_all():
        # Clears both the visible chat and the stored history state
        return [], [], None, ""

    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)

    # Define what happens when the user submits.
    # The outputs list is slightly different now:
    # chatbot gets the display_history, chat_history gets the model history.
    submit_btn.click(
        fn=symptom_checker_chat,
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )
    text_box.submit(
        fn=symptom_checker_chat,
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True)  # Debug mode for more detailed logs