import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import os
import spaces # <-- FIX 1: IMPORT SPACES
# Get the Hugging Face token from the environment variables
# Make sure to set this as a "Secret" in your Hugging Face Space settings
hf_token = os.environ.get("HF_TOKEN")
# Initialize the processor and model
# We are using MedGemma, a 4B parameter model specialized for medical text and images.
model_id = "google/medgemma-4b-it"
# Check for GPU availability and set the data type accordingly
# Using bfloat16 for better performance on compatible GPUs.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    # Fallback to float16 if bfloat16 is not available
    dtype = torch.float16
model_loaded = False
# Load the processor and model from Hugging Face
try:
    # AutoProcessor handles both text tokenization and image processing
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    # We will display an error in the UI if the model fails to load.
# This is the core function for the chatbot
@spaces.GPU
def symptom_checker_chat(user_input, history, image_input):
    """
    Manages the conversational flow for the symptom checker.
    """
    if not model_loaded:
        history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
        return history, history, None, ""

    system_prompt = """
    You are an expert, empathetic AI medical assistant... (rest of your prompt is fine)
    ***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
    """

    # Create the text input for the model. This will include the <image> token if needed.
    # This is what the model sees.
    if image_input:
        model_input_text = f"<image>\n{user_input}"
    else:
        model_input_text = user_input

    # Construct the conversation history for the model
    # We use the full 'model_input_text' from previous turns if available.
    conversation = [{"role": "system", "content": system_prompt}]
    for turn_input, assistant_output in history:
        # The history now stores the text that was ACTUALLY sent to the model
        conversation.append({"role": "user", "content": turn_input})
        if assistant_output:
            conversation.append({"role": "assistant", "content": assistant_output})
    # Add the current user turn
    conversation.append({"role": "user", "content": model_input_text})

    # --- FIX 1: Use the main processor for templating, not the tokenizer ---
    # This correctly handles multimodal (text + image) conversations.
    prompt = processor.apply_chat_template(
        conversation,
        tokenize=False,
        add_generation_prompt=True
    )

    # Process inputs, including the image if it exists
    if image_input:
        inputs = processor(text=prompt, images=image_input, return_tensors="pt").to(model.device, dtype)
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)

    # Generate the output from the model
    try:
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error during model generation: {e}")
        clean_response = "An error occurred while generating the response. Please check the logs."

    # --- FIX 2: Store the correct information in history ---
    # The history state should contain the text with the <image> token to maintain context...
    history.append((model_input_text, clean_response))

    # ... but the chatbot UI should display the clean user input.
    # We construct the display history here.
    display_history = []
    for turn_input, assistant_output in history:
        # Strip the <image> token for display purposes
        display_input = turn_input.replace("<image>\n", "")
        display_history.append((display_input, assistant_output))

    # Return values: update chatbot UI, update history state, clear image, clear textbox
    return display_history, history, None, ""
# --- Modify the Gradio part to match the new return signature ---
# ... (your gr.Blocks setup is the same) ...
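# The Blocks layout itself is elided above; the lines below assume a minimal sketch
# so that the names referenced later (demo, image_box, text_box, submit_btn, clear_btn)
# are actually defined. The layout and widget names here are assumptions, not the
# original setup; adapt them to your real UI.
with gr.Blocks() as demo:
    gr.Markdown("# AI Symptom Checker (MedGemma)")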
    # The Chatbot component and State
    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))
    chat_history = gr.State([])  # This will store the history WITH the <image> tokens

    # ... (image_box, text_box, buttons are the same) ...
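    # Again an assumed sketch: the original input widgets are elided, so these are
    # hypothetical definitions whose names match the event handlers further down.
    image_box = gr.Image(type="pil", label="Upload a medical image (optional)")
    text_box = gr.Textbox(label="Describe your symptoms", placeholder="e.g. I have a rash on my arm...")
    with gr.Row():
        submit_btn = gr.Button("Send", variant="primary")
        clear_btn = gr.Button("Clear")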
    # Function to clear all inputs
    def clear_all():
        # Now it clears both the display and the state
        return [], [], None, ""

    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)

    # Define what happens when the user submits
    # The outputs list is slightly different now.
    submit_btn.click(
        fn=symptom_checker_chat,
        inputs=[text_box, chat_history, image_box],
        # chatbot gets the display_history, chat_history gets the model history
        outputs=[chatbot, chat_history, image_box, text_box]
    )
    text_box.submit(
        fn=symptom_checker_chat,
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box]
    )
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True)  # Debug mode for more detailed logs