import os

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import AutoModelForImageTextToText, AutoProcessor

# --- Configuration ---
hf_token = os.environ.get("HF_TOKEN")
model_id = "google/medgemma-4b-it"

# --- Model Loading ---
# bfloat16 requires compute capability >= 8.0 (Ampere or newer); otherwise fall back to float16.
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    dtype = torch.float16

model_loaded = False
try:
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    # MedGemma 4B-it is multimodal, so load it with the image-text-to-text class
    # rather than AutoModelForCausalLM, which would give the text-only variant.
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
    print("Model loaded successfully on device:", model.device)
except Exception as e:
    print(f"Error loading model: {e}")


# --- Core Chatbot Function ---
@spaces.GPU
def symptom_checker_chat(user_input, history_state, new_image_upload, image_state):
    """
    Manages the conversational flow by manually building the prompt to ensure
    correct handling of the <start_of_image> token.
    """
    if not model_loaded:
        history_state.append((user_input, "Error: The model could not be loaded."))
        return history_state, history_state, None, None, ""

    current_image = new_image_upload if new_image_upload is not None else image_state

    # --- FIX: Manual Prompt Construction ---
    # This gives us full control and bypasses the opaque apply_chat_template behavior.
    # The system prompt is not included in the turns but prepended as a prefix.
    system_prompt = "You are an expert, empathetic AI medical assistant..."  # Keep your full system prompt

    # Build the prompt from history using Gemma's turn markers.
    prompt_parts = [f"{system_prompt}\n\n"]
    for turn_input, assistant_output in history_state:
        # Add a user turn from history
        prompt_parts.append(f"<start_of_turn>user\n{turn_input}<end_of_turn>\n")
        # Add a model turn from history
        if assistant_output:
            prompt_parts.append(f"<start_of_turn>model\n{assistant_output}<end_of_turn>\n")

    # Add the current user turn
    prompt_parts.append("<start_of_turn>user\n")

    # The MOST IMPORTANT PART: add the <start_of_image> token if an image is present.
    # We add it for a new upload OR if we're in a conversation that already had an image.
    if current_image:
        prompt_parts.append("<start_of_image>\n")

    prompt_parts.append(f"{user_input}<end_of_turn>\n")

    # Add the generation prompt for the model to start its response
    prompt_parts.append("<start_of_turn>model\n")

    # Join everything into a single string
    final_prompt = "".join(prompt_parts)

    try:
        # Process the inputs using our manually built prompt
        if current_image:
            inputs = processor(
                text=final_prompt, images=[current_image], return_tensors="pt"
            ).to(model.device, dtype)
        else:
            inputs = processor(text=final_prompt, return_tensors="pt").to(model.device, dtype)

        # Generate the response, then keep only the newly generated tokens
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()

    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        # Display the real error in the UI for easier debugging
        clean_response = (
            "An error occurred during generation. Here are the technical details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    # --- History Management ---
    # For history, we save the user_input along with a marker if an image was present.
    # We use the same <start_of_image> token we've been using as that marker.
    history_input = user_input
    if current_image:
        history_input = f"<start_of_image>\n{user_input}"
    history_state.append((history_input, clean_response))

    # Create display history without the special tokens
    display_history = [
        (turn.replace("<start_of_image>\n", ""), resp) for turn, resp in history_state
    ]

    # Return all updated values
    return display_history, history_state, current_image, None, ""


# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash).
        The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
        """
    )

    image_state = gr.State(None)
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)
    chat_history = gr.State([])

    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    with gr.Row():
        text_box = gr.Textbox(
            label="Describe your symptoms...",
            placeholder="e.g., I have a rash on my arm...",
            scale=4,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    def clear_all():
        return [], [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(
        fn=clear_all,
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
        queue=False,
    )

    def on_submit(user_input, history, new_image, persisted_image):
        # Ignore empty submissions
        if not user_input.strip():
            return history, history, persisted_image, None, ""
        return symptom_checker_chat(user_input, history, new_image, persisted_image)

    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chat_history, image_box, image_state],
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chat_history, image_box, image_state],
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
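
# For reference: a minimal sketch of the apply_chat_template route that
# symptom_checker_chat deliberately bypasses, following the message format from
# the MedGemma model card. `build_inputs` is a hypothetical helper, not used by
# the app above; it builds a single-turn prompt rather than the full history.
def build_inputs(processor, model, system_prompt, user_text, image=None):
    """Build generation inputs via the processor's chat template instead of a manual prompt string."""
    content = [{"type": "text", "text": user_text}]
    if image is not None:
        # The template inserts the <start_of_image> token for image entries itself.
        content.insert(0, {"type": "image", "image": image})
    messages = [
        {"role": "system", "content": [{"type": "text", "text": system_prompt}]},
        {"role": "user", "content": content},
    ]
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device)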