import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import os
import spaces  # required for the @spaces.GPU decorator on ZeroGPU Spaces
# Get the Hugging Face token from the environment variables
# Make sure to set this as a "Secret" in your Hugging Face Space settings
hf_token = os.environ.get("HF_TOKEN")
# Initialize the processor and model
# We are using MedGemma, a 4B parameter model specialized for medical text and images.
model_id = "google/medgemma-4b-it"
# Check for GPU availability and set the data type accordingly
# Using bfloat16 for better performance on compatible GPUs.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    # Fall back to float16 if bfloat16 is not available
    dtype = torch.float16
model_loaded = False
# Load the processor and model from Hugging Face
try:
    # AutoProcessor handles both text tokenization and image processing
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    # An error will be displayed in the UI if the model fails to load.
# This is the core function for the chatbot
@spaces.GPU
def symptom_checker_chat(user_input, history, image_input):
"""
Manages the conversational flow for the symptom checker.
"""
if not model_loaded:
history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
return history, history, None, ""
system_prompt = """
You are an expert, empathetic AI medical assistant... (rest of your prompt is fine)
***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
"""
    # Create the text input for the model. This includes the <image> token if needed.
    # This is exactly what the model sees.
    if image_input:
        model_input_text = f"<image>\n{user_input}"
    else:
        model_input_text = user_input

    # Construct the conversation history for the model,
    # reusing the full model-facing text ('model_input_text') from previous turns.
    conversation = [{"role": "system", "content": system_prompt}]
    for turn_input, assistant_output in history:
        # The history stores the text that was actually sent to the model
        conversation.append({"role": "user", "content": turn_input})
        if assistant_output:
            conversation.append({"role": "assistant", "content": assistant_output})

    # Add the current user turn
    conversation.append({"role": "user", "content": model_input_text})
    # Apply the chat template via the main processor (not the tokenizer alone),
    # so multimodal (text + image) conversations are handled correctly.
    prompt = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )

    # Process inputs, including the image if one was provided
    if image_input:
        inputs = processor(text=prompt, images=image_input, return_tensors="pt").to(model.device, dtype)
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)
    # Generate the output from the model
    try:
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        # Keep only the newly generated tokens, not the echoed prompt
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error during model generation: {e}")
        clean_response = "An error occurred while generating the response. Please check the logs."

    # Store the model-facing text (including the <image> token) in the history state
    # so the multimodal context is preserved across turns...
    history.append((model_input_text, clean_response))

    # ...but show the clean user input in the chatbot UI.
    display_history = []
    for turn_input, assistant_output in history:
        # Strip the <image> token for display purposes
        display_input = turn_input.replace("<image>\n", "")
        display_history.append((display_input, assistant_output))

    # Return values: update chatbot UI, update history state, clear image, clear textbox
    return display_history, history, None, ""
# --- Gradio UI ---
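# NOTE: The original gr.Blocks layout was elided in this revision
# ("... your gr.Blocks setup is the same ..."). The block below is a minimal,
# assumed reconstruction so the file runs end to end; the title and markdown
# are placeholders, not the original Space's exact markup.
with gr.Blocks(title="AI Symptom Checker") as demo:
    gr.Markdown("## AI Symptom Checker (MedGemma demo)")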
    # The Chatbot component and the conversation State
    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))
    chat_history = gr.State([])  # This will store the history WITH the <image> tokens
    # Input widgets: image upload, symptom textbox, and the submit/clear buttons
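    # NOTE: These component definitions were also elided ("... are the same ...");
    # the ones below are assumed, minimal equivalents so the click handlers have
    # something to bind to. Labels and placeholder text are guesses.
    image_box = gr.Image(type="pil", label="Upload a medical image (optional)")
    text_box = gr.Textbox(label="Describe your symptoms", placeholder="e.g. I have a persistent cough and a mild fever...")
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")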
    # Function to clear all inputs
    def clear_all():
        # Clears the chatbot display, the history state, the image, and the textbox
        return [], [], None, ""

    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)
    # Define what happens when the user submits via the button or the textbox
    submit_btn.click(
        fn=symptom_checker_chat,
        inputs=[text_box, chat_history, image_box],
        # chatbot gets the display_history, chat_history gets the model-facing history
        outputs=[chatbot, chat_history, image_box, text_box]
    )

    text_box.submit(
        fn=symptom_checker_chat,
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True)  # debug mode gives more detailed logs