import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForCausalLM
from PIL import Image
import os
import spaces
# --- Configuration ---
hf_token = os.environ.get("HF_TOKEN")
model_id = "google/medgemma-4b-it"
# --- Model Loading ---
# Use bfloat16 on GPUs with compute capability >= 8.0 (Ampere or newer);
# fall back to float16 otherwise.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16
model_loaded = False
try:
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
    print("Model loaded successfully on device:", model.device)
except Exception as e:
    print(f"Error loading model: {e}")
# --- Core Chatbot Function ---
@spaces.GPU
def symptom_checker_chat(user_input, history_state, new_image_upload, image_state):
"""
Manages the conversational flow by manually building the prompt to ensure
correct handling of the <image> token.
"""
if not model_loaded:
history_state.append((user_input, "Error: The model could not be loaded."))
return history_state, history_state, None, None, ""
    # Use a newly uploaded image if provided; otherwise fall back to the image
    # persisted from earlier in the conversation.
    current_image = new_image_upload if new_image_upload is not None else image_state

    # --- FIX: Manual Prompt Construction ---
    # This gives us full control and bypasses the opaque apply_chat_template behavior.
    # The system prompt is not included in the turns; it is prepended as a plain-text prefix.
    system_prompt = "You are an expert, empathetic AI medical assistant..."  # Keep your full system prompt

    # Build the prompt: system-prompt prefix first, then the turns from history.
    prompt_parts = [f"{system_prompt}\n\n"]
    for turn_input, assistant_output in history_state:
        # Add a user turn from history.
        prompt_parts.append(f"<start_of_turn>user\n{turn_input}<end_of_turn>\n")
        # Add a model turn from history.
        if assistant_output:
            prompt_parts.append(f"<start_of_turn>model\n{assistant_output}<end_of_turn>\n")

    # Add the current user turn.
    prompt_parts.append("<start_of_turn>user\n")
    # The MOST IMPORTANT PART: add the <image> token if an image is present.
    # We add it for a new upload OR if we're in a conversation that already had an image.
    if current_image:
        prompt_parts.append("<image>\n")
    prompt_parts.append(f"{user_input}<end_of_turn>\n")
    # Add the generation prompt so the model starts its response.
    prompt_parts.append("<start_of_turn>model\n")

    # Join everything into a single string.
    final_prompt = "".join(prompt_parts)
    try:
        # Process the inputs using our manually built prompt.
        if current_image:
            inputs = processor(text=final_prompt, images=[current_image], return_tensors="pt").to(model.device, dtype)
        else:
            inputs = processor(text=final_prompt, return_tensors="pt").to(model.device, dtype)

        # Generate, then slice off the prompt tokens so only new text is decoded.
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        # Display the real error in the UI for easier debugging.
        clean_response = (
            "An error occurred during generation. Here are the technical details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )
    # --- History Management ---
    # Store the user input together with a marker recording that an image was
    # attached; we reuse the same "<image>\n" token as that marker.
    history_input = user_input
    if current_image:
        history_input = f"<image>\n{user_input}"
    history_state.append((history_input, clean_response))

    # Create the display history without the special tokens.
    display_history = [(turn.replace("<image>\n", ""), resp) for turn, resp in history_state]

    # Return the visible chat, the raw history, the persisted image, a cleared
    # upload box, and a cleared text box.
    return display_history, history_state, current_image, None, ""
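# For reference, the template-driven path that the manual construction above
# replaces would look roughly like the sketch below. This is a hedged
# assumption about the processor's chat template (the message schema and
# keyword names vary across transformers versions), not the method this app uses.
#
# messages = [{"role": "user", "content": [
#     {"type": "image", "image": current_image},
#     {"type": "text", "text": user_input},
# ]}]
# inputs = processor.apply_chat_template(
#     messages, add_generation_prompt=True, tokenize=True,
#     return_dict=True, return_tensors="pt",
# ).to(model.device)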
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash). The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
        """
    )

    # image_state persists the uploaded image across turns, since the upload
    # box itself is cleared after every submission.
    image_state = gr.State(None)
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
    chat_history = gr.State([])

    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    with gr.Row():
        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
        submit_btn = gr.Button("Send", variant="primary", scale=1)
    # Reset everything: chat display, raw history, persisted image, upload box, text box.
    def clear_all():
        return [], [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(
        fn=clear_all,
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
        queue=False,
    )
    def on_submit(user_input, history, new_image, persisted_image):
        # Ignore empty submissions; re-render the existing history (with the
        # <image> markers stripped) instead of calling the model.
        if not user_input.strip():
            display = [(turn.replace("<image>\n", ""), resp) for turn, resp in history]
            return display, history, persisted_image, None, ""
        return symptom_checker_chat(user_input, history, new_image, persisted_image)

    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chat_history, image_box, image_state],
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chat_history, image_box, image_state],
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
    )
if __name__ == "__main__":
    demo.launch(debug=True)