import os

import gradio as gr
import spaces
import torch
from PIL import Image
from transformers import pipeline

# --- Configuration & Model Loading ---
# Use the pipeline API, which is more robust, as seen in the working example.
print("Loading MedGemma model via pipeline...")
try:
    pipe = pipeline(
        "image-text-to-text",  # Chat-style multimodal task used by MedGemma
        model="google/medgemma-4b-it",
        model_kwargs={"torch_dtype": torch.bfloat16},  # Pass the dtype here
        device_map="auto",
        token=os.environ.get("HF_TOKEN"),
    )
    model_loaded = True
    print("Model loaded successfully!")
except Exception as e:
    model_loaded = False
    print(f"Error loading model: {e}")


# --- Core Chatbot Function ---
@spaces.GPU
def symptom_checker_chat(user_input, history_for_display, new_image_upload, image_state):
    """
    Manages the conversation using the chat message format derived from the
    working example.
    """
    if not model_loaded:
        if user_input:
            history_for_display.append((user_input, "Error: The model could not be loaded."))
        return history_for_display, image_state, None, ""

    current_image = new_image_upload if new_image_upload is not None else image_state

    # Build the 'messages' list using the chat format from the working X-ray app.
    messages = []
    # Optional: a system prompt can be added here if needed, following the same format.

    # Process the conversation history. The image (if any) is attached to the
    # first user turn; every other historical turn is text-only.
    for i, (user_msg, assistant_msg) in enumerate(history_for_display):
        user_content = [{"type": "text", "text": user_msg}]
        if i == 0 and current_image is not None:
            user_content.insert(0, {"type": "image", "image": current_image})
        messages.append({"role": "user", "content": user_content})
        if assistant_msg:
            messages.append(
                {"role": "assistant", "content": [{"type": "text", "text": assistant_msg}]}
            )

    # Add the current user turn.
    current_user_content = [{"type": "text", "text": user_input}]
    # If this is the very first message and an image was provided, attach it here.
    if current_image is not None and not history_for_display:
        current_user_content.insert(0, {"type": "image", "image": current_image})
    messages.append({"role": "user", "content": current_user_content})

    try:
        # Generate the analysis with the pipeline; the image (if any) travels
        # inside `messages`, so the same call works for text-only turns.
        output = pipe(text=messages, max_new_tokens=512)

        # For chat-style input the pipeline returns the whole conversation with
        # the assistant's reply appended as the last message; extract its text.
        result = output[0]["generated_text"]
        if isinstance(result, list):
            clean_response = result[-1].get("content") or "Sorry, I couldn't generate a response."
        else:
            # Simpler text-only output
            clean_response = result
    except Exception as e:
        print(f"Caught a critical exception during generation: {e}", flush=True)
        clean_response = (
            "An error occurred during generation. Details:\n\n"
            f"```\n{type(e).__name__}: {e}\n```"
        )

    history_for_display.append((user_input, clean_response))
    return history_for_display, current_image, None, ""
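
# A minimal sketch (not part of the original app) of the payload that
# symptom_checker_chat builds for a single turn, assuming the chat format of the
# "image-text-to-text" pipeline used above; `pil_image` is a hypothetical
# PIL.Image.Image provided by the caller:
#
#   messages = [
#       {"role": "user", "content": [
#           {"type": "image", "image": pil_image},
#           {"type": "text", "text": "I have a rash on my arm. What could it be?"},
#       ]},
#   ]
#   output = pipe(text=messages, max_new_tokens=512)
#   print(output[0]["generated_text"][-1]["content"])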

# --- Gradio Interface ---
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms below. For visual symptoms (e.g., a skin rash), upload an image.
        The AI will analyze the inputs and ask clarifying questions if needed.
        """
    )

    # State to hold the image across an entire conversation
    image_state = gr.State(None)

    # The Chatbot component itself holds the display history (a list of
    # (user_text, assistant_text) tuples), so it is wired directly into the
    # event handlers and the conversation actually renders in the UI.
    chatbot = gr.Chatbot(label="Conversation", height=500, bubble_full_width=False)

    with gr.Row():
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    with gr.Row():
        text_box = gr.Textbox(
            label="Describe your symptoms...",
            placeholder="e.g., I have a rash on my arm...",
            scale=4,
        )
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # The clear function resets the chat display, both image inputs, and the textbox.
    def clear_all():
        return [], None, None, ""

    clear_btn = gr.Button("Start New Conversation")
    clear_btn.click(
        fn=clear_all,
        outputs=[chatbot, image_state, image_box, text_box],
        queue=False,
    )

    # The submit handler: ignore empty submissions, otherwise delegate to the chat function.
    def on_submit(user_input, display_history, new_image, persisted_image):
        if not user_input.strip() and not new_image:
            return display_history, persisted_image, None, ""
        # The Chatbot value IS our history state now.
        return symptom_checker_chat(user_input, display_history, new_image, persisted_image)

    # Wire up the events
    submit_btn.click(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box],
    )
    text_box.submit(
        fn=on_submit,
        inputs=[text_box, chatbot, image_box, image_state],
        outputs=[chatbot, image_state, image_box, text_box],
    )

if __name__ == "__main__":
    demo.launch(debug=True)
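
# A sketch of the Space's requirements.txt (an assumption, not taken from the original):
#
#   gradio
#   transformers
#   torch
#   accelerate   # needed for device_map="auto"
#   spaces       # provides the @spaces.GPU decorator on ZeroGPU Spaces
#   Pillow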