hponepyae's picture
Update app.py
b47c12e verified
raw
history blame
6.16 kB
import os

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForCausalLM, AutoModelForImageTextToText, AutoProcessor
# --- Configuration ---
# Access token for the model repository, read from the HF_TOKEN environment
# variable (None when unset).
hf_token = os.environ.get("HF_TOKEN")
model_id = "google/medgemma-4b-it"

# --- Model Loading ---
# Use bfloat16 when a GPU reports compute capability >= 8 (Ampere-class or
# newer); otherwise fall back to float16. The same dtype is reused when the
# processor outputs are moved to the model in symptom_checker_chat.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16

# `processor` and `model` are only defined when loading succeeds; code using
# them must check `model_loaded` first.
model_loaded = False
try:
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    # MedGemma 4B-it accepts image + text input. The image-text-to-text auto
    # class (the one used in the model card) resolves to the full multimodal
    # model. Depending on the installed transformers version,
    # AutoModelForCausalLM may resolve to a text-only class without the vision
    # tower, which makes image prompts fail.
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
    print("Model loaded successfully on device:", model.device)
except Exception as e:
    # Best-effort startup: keep the UI running so the chat handler can report
    # the failure, but print the full traceback for debugging.
    import traceback

    print(f"Error loading model: {e}")
    traceback.print_exc()
# --- Core Chatbot Function ---
# Bookkeeping marker prepended to user turns stored in the history when an image
# was attached. It is stripped before display and before the prompt is rebuilt,
# so it never reaches the model as text.
_HISTORY_IMAGE_MARKER = "<image>\n"


def _build_prompt(system_prompt, history, user_input, image_token=None):
    """Build a Gemma-style chat prompt from past turns plus the current message.

    Args:
        system_prompt: Instruction text added as a prefix to the first user turn.
            Skipped when empty.
        history: Sequence of ``(user_text, model_text)`` pairs from earlier turns.
            Any history image markers are removed from the user texts. A falsy
            ``model_text`` produces no model turn.
        user_input: Text of the current user turn.
        image_token: When given, placed once at the start of the current user
            turn, so the prompt holds exactly one image placeholder for the
            single image handed to the processor.

    Returns:
        The prompt string, ending with an open ``model`` turn for generation.
    """
    turns = []
    for past_input, past_output in history:
        turns.append(("user", past_input.replace(_HISTORY_IMAGE_MARKER, "")))
        if past_output:
            turns.append(("model", past_output))

    current_text = user_input if image_token is None else f"{image_token}\n{user_input}"
    turns.append(("user", current_text))

    # turns[0] is always a user turn: history starts with one, and the current
    # turn is a user turn.
    if system_prompt:
        first_role, first_text = turns[0]
        turns[0] = (first_role, f"{system_prompt}\n\n{first_text}")

    prompt = "".join(f"<start_of_turn>{role}\n{text}<end_of_turn>\n" for role, text in turns)
    return prompt + "<start_of_turn>model\n"


@spaces.GPU
def symptom_checker_chat(user_input, history_state, new_image_upload, image_state):
    """Run one chat turn: build the prompt, generate a reply, update the history.

    Args:
        user_input: The user's current message.
        history_state: List of ``(user_text, model_text)`` tuples. User texts of
            turns that had an image carry a leading ``"<image>\\n"`` marker. The
            list is appended to in place.
        new_image_upload: Image from the upload box for this turn, or None.
        image_state: Image persisted from an earlier turn, or None.

    Returns:
        A 5-tuple with:
        - the chat history with image markers removed (for display),
        - the raw history,
        - the image to keep persisting,
        - None to clear the upload box,
        - "" to clear the text box.
    """
    if not model_loaded:
        history_state.append((user_input, "Error: The model could not be loaded."))
        return history_state, history_state, None, None, ""

    # A fresh upload takes precedence over the image persisted from earlier turns.
    current_image = new_image_upload if new_image_upload is not None else image_state
    has_image = current_image is not None

    system_prompt = "You are an expert, empathetic AI medical assistant..."  # Keep your full system prompt

    # The Gemma 3 processor locates images via its begin-of-image token
    # (`boi_token`, "<start_of_image>") and requires exactly one per image passed.
    # A literal "<image>" string is not recognised as a placeholder.
    image_token = getattr(processor, "boi_token", "<start_of_image>") if has_image else None
    final_prompt = _build_prompt(system_prompt, history_state, user_input, image_token)

    try:
        if has_image:
            inputs = processor(text=final_prompt, images=[current_image], return_tensors="pt")
        else:
            inputs = processor(text=final_prompt, return_tensors="pt")
        inputs = inputs.to(model.device, dtype)
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        # Decode only the newly generated tokens, not the echoed prompt.
        input_token_len = inputs["input_ids"].shape[1]
        clean_response = processor.decode(outputs[0, input_token_len:], skip_special_tokens=True).strip()
    except Exception as e:
        import traceback

        print(f"Caught a critical exception during generation: {e}", flush=True)
        traceback.print_exc()
        # Surface the real error in the UI for easier debugging.
        clean_response = (
            "An error occurred during generation. This is the technical details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    # --- History Management ---
    # Record whether this turn had an image using the bookkeeping marker.
    history_input = f"{_HISTORY_IMAGE_MARKER}{user_input}" if has_image else user_input
    history_state.append((history_input, clean_response))

    display_history = [(turn.replace(_HISTORY_IMAGE_MARKER, ""), resp) for turn, resp in history_state]
    return display_history, history_state, current_image, None, ""
# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
# AI Symptom Checker powered by MedGemma
Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash). The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
"""
    )
    # Image persisted across turns, and the raw chat history. The raw history's
    # user texts may carry "<image>\n" markers.
    image_state = gr.State(None)
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
    chat_history = gr.State([])

    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
    with gr.Row():
        text_box = gr.Textbox(label="Describe your symptoms...", placeholder="e.g., I have a rash on my arm...", scale=4)
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    def clear_all():
        """Reset the chat display, history, persisted image, upload box and text box."""
        return [], [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(
        fn=clear_all,
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
        queue=False,
    )

    def on_submit(user_input, history, new_image, persisted_image):
        """Forward a non-empty message to the chatbot; otherwise change nothing visible.

        For an empty or whitespace-only message:
        - the chatbot is re-rendered from the history with image markers
          stripped, so raw markers are never shown;
        - the upload box is left untouched, so a freshly uploaded image is not
          silently discarded.
        """
        if not (user_input or "").strip():
            display_history = [(turn.replace("<image>\n", ""), resp) for turn, resp in history]
            return display_history, history, persisted_image, gr.update(), ""
        return symptom_checker_chat(user_input, history, new_image, persisted_image)

    # Both the Send button and pressing Enter in the text box run the same handler.
    submit_inputs = [text_box, chat_history, image_box, image_state]
    submit_outputs = [chatbot, chat_history, image_state, image_box, text_box]
    submit_btn.click(fn=on_submit, inputs=submit_inputs, outputs=submit_outputs)
    text_box.submit(fn=on_submit, inputs=submit_inputs, outputs=submit_outputs)

if __name__ == "__main__":
    demo.launch(debug=True)