import gradio as gr
from transformers import pipeline
from PIL import Image
import torch
import os
import spaces

# --- Configuration & Model Loading ---
print("Loading MedGemma model via pipeline...")
model_loaded = False
pipe = None
try:
    # MedGemma is a multimodal chat model, so "image-text-to-text" is its pipeline task.
    pipe = pipeline(
        "image-text-to-text",
        model="google/medgemma-4b-it",
        model_kwargs={"torch_dtype": torch.bfloat16},
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")

# --- Core Chatbot Function ---
@spaces.GPU(duration=120)
def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
    """
    Manages the conversation. The current image (if any) is embedded in the
    latest user turn, which is how the multimodal pipeline expects it.
    """
    if not model_loaded:
        if user_input:
            history_for_display.append((user_input, "Error: The model could not be loaded."))
        return history_for_display, history_for_display, image_state, None, ""

    current_image = new_image_upload if new_image_upload is not None else image_state

    # 1. Rebuild the conversation history in the chat-message format the
    #    pipeline expects: each turn's content is a list of typed parts.
    messages = []
    for user_msg, assistant_msg in history_for_display:
        messages.append({"role": "user", "content": [{"type": "text", "text": user_msg}]})
        if assistant_msg:
            messages.append({"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]})

    # Add the current user's turn, attaching the image to it when present.
    user_content = [{"type": "text", "text": user_input}]
    if current_image is not None:
        user_content.append({"type": "image", "image": current_image})
    messages.append({"role": "user", "content": user_content})

    try:
        # 2. A single call path covers both text-only and multimodal turns,
        #    because the image travels inside the message list.
        generate_kwargs = {"max_new_tokens": 512, "do_sample": True, "temperature": 0.7}
        output = pipe(text=messages, generate_kwargs=generate_kwargs)

        # 3. The pipeline returns the full conversation; the model's reply is
        #    the content of the last message.
        clean_response = output[0]["generated_text"][-1]["content"]
    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        clean_response = (
            "An error occurred during generation. Details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    # Update history and return values for the Gradio UI.
    history_for_display.append((user_input, clean_response))
    return history_for_display, history_for_display, current_image, None, ""

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    # 1. UI components are defined first.
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image.
        """
    )
    image_state = gr.State(None)
    chat_history = gr.State([])
    chatbot = gr.Chatbot(
        label="Conversation",
        height=500,
        bubble_full_width=False,
        avatar_images=("user.png", "bot.png"),
    )
    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")
    with gr.Row():
        text_box = gr.Textbox(
            label="Describe your symptoms...",
            placeholder="e.g., I have a rash on my arm that is red and itchy...",
            scale=4,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)
    # The clear button must be defined before its .click event is set up.
    clear_btn = gr.Button("Clear Conversation")

    # 2. Helper functions are defined next.
    def clear_all():
        # Returns empty values for every component in the outputs list below.
        return [], [], None, None, ""

    def on_submit(user_input, display_history, new_image, persisted_image):
        # A wrapper function to handle the user submission logic.
        if not (user_input or "").strip() and new_image is None:
            return display_history, display_history, persisted_image, None, ""
        return symptom_checker_chat(user_input, display_history, new_image, persisted_image)

    # 3. Event handlers are defined last, linking components to functions.
    # The Chatbot must appear in the outputs: updating the chat_history state
    # alone is invisible, so the history is written to both components.
    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chat_history, image_box, image_state],
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chat_history, image_box, image_state],
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
    )
    clear_btn.click(
        fn=clear_all,
        # The outputs list must match the values returned by clear_all().
        outputs=[chatbot, chat_history, image_state, image_box, text_box],
        queue=False,
    )

if __name__ == "__main__":
    demo.launch(debug=True)
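# --- Optional smoke test (a hedged sketch, not part of the app) ---
# Uncomment to exercise the chat function once without the UI. This assumes
# the model loaded successfully and that HF_TOKEN grants access to
# google/medgemma-4b-it. The five return values mirror the event-handler
# outputs above (chatbot, chat_history, image_state, image_box, text_box).
#
# if model_loaded:
#     display, _, _, _, _ = symptom_checker_chat(
#         "I have a dry cough and a mild fever. What could this be?", [], None, None
#     )
#     print(display[-1][1])  # the model's reply to the single turn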