import gradio as gr
import torch
from transformers import AutoProcessor, AutoModelForImageTextToText
from PIL import Image
import os
import spaces  # <-- FIX 1: IMPORT SPACES

# Get the Hugging Face token from the environment variables.
# Make sure to set this as a "Secret" in your Hugging Face Space settings.
hf_token = os.environ.get("HF_TOKEN")

# Initialize the processor and model.
# We are using MedGemma, a 4B-parameter model specialized for medical text and images.
model_id = "google/medgemma-4b-it"

# Check for GPU availability and set the data type accordingly.
# bfloat16 performs better on GPUs with compute capability >= 8 (Ampere and newer).
if torch.cuda.is_available() and torch.cuda.get_device_capability()[0] >= 8:
    dtype = torch.bfloat16
else:
    # Fall back to float16 if bfloat16 is not available.
    dtype = torch.float16

model_loaded = False

# Load the processor and model from Hugging Face.
try:
    # AutoProcessor handles both text tokenization and image preprocessing.
    processor = AutoProcessor.from_pretrained(model_id, token=hf_token)
    # MedGemma is a multimodal (image-text-to-text) checkpoint, so it is loaded
    # with AutoModelForImageTextToText rather than the text-only AutoModelForCausalLM.
    model = AutoModelForImageTextToText.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    # We will display an error in the UI if the model fails to load.
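# Optional smoke test for local debugging (left commented out so it never runs in
# the Space): it renders a toy conversation through the same chat template the
# handler below uses, so you can inspect the exact prompt string the model sees.
#
#   if model_loaded:
#       demo_conv = [
#           {"role": "system", "content": "You are a helpful assistant."},
#           {"role": "user", "content": "Hello!"},
#       ]
#       print(processor.tokenizer.apply_chat_template(
#           demo_conv, tokenize=False, add_generation_prompt=True
#       ))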
# This is the core function for the chatbot.
@spaces.GPU  # <-- FIX 1: ADD THE GPU DECORATOR (requests a ZeroGPU slice for this call)
def symptom_checker_chat(user_input, history, image_input):
    """
    Manages the conversational flow for the symptom checker.
    """
    if not model_loaded:
        history.append((user_input, "Error: The model could not be loaded. Please check the Hugging Face Space logs."))
        # <-- FIX 3 & 4: Return values match the new outputs list.
        return history, history, None, ""

    # System prompt to guide the model's behavior.
    system_prompt = """You are an expert, empathetic AI medical assistant. Your role is to analyze a user's symptoms and provide a helpful, safe, and informative response.

Here is your workflow:
1. Analyze the user's initial input, which may include text and an image.
2. If the information is insufficient, ask specific, relevant clarifying questions to better understand the symptoms (e.g., "How long have you had this symptom?", "Can you describe the pain? Is it sharp or dull?").
3. Once you have gathered enough information, provide a list of possible conditions that might align with the symptoms.
4. For each possible condition, briefly explain why it might be relevant.
5. Provide a clear, actionable plan, such as "It would be best to monitor your symptoms" or "You should consider consulting a healthcare professional."
6. **Crucially, you must ALWAYS end every single response with the following disclaimer, formatted exactly like this, on a new line:**

***Disclaimer: I am an AI assistant and not a medical professional. This is not a diagnosis. Please consult a doctor for any health concerns.***
"""

    # Construct the conversation history for the model.
    conversation = [{"role": "system", "content": system_prompt}]
    for user, assistant in history:
        conversation.append({"role": "user", "content": user})
        if assistant:  # Ensure the assistant message is not None.
            conversation.append({"role": "assistant", "content": assistant})

    # Add the current user input, prefixed with an image placeholder when an image is present.
    if image_input:
        # MedGemma expects the text to start with an <image> token if an image is provided.
        conversation.append({"role": "user", "content": f"<image>\n{user_input}"})
    else:
        conversation.append({"role": "user", "content": user_input})

    # Apply the chat template.
    prompt = processor.tokenizer.apply_chat_template(
        conversation, tokenize=False, add_generation_prompt=True
    )

    # Process the inputs, including the image if it exists.
    # (BatchFeature.to() casts only floating-point tensors, so input_ids stay integer.)
    if image_input:
        inputs = processor(text=prompt, images=image_input, return_tensors="pt").to(model.device, dtype)
    else:
        inputs = processor(text=prompt, return_tensors="pt").to(model.device, dtype)

    # Generate the output from the model.
    try:
        outputs = model.generate(**inputs, max_new_tokens=512, do_sample=True, temperature=0.7)

        # <-- FIX 2: ROBUST RESPONSE PARSING
        # Decode only the newly generated tokens, not the whole conversation.
        input_token_len = inputs["input_ids"].shape[1]
        generated_tokens = outputs[:, input_token_len:]
        clean_response = processor.decode(generated_tokens[0], skip_special_tokens=True).strip()
    except Exception as e:
        print(f"Error during model generation: {e}")
        clean_response = "An error occurred while generating the response. Please check the logs."

    # Update the history.
    history.append((user_input, clean_response))

    # <-- FIX 3 & 4: Return values update the state, clear the image box, and clear the text box.
    return history, history, None, ""
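# Alternative input path (a sketch, not wired into the app): recent transformers
# releases let the *processor's* chat template handle images via structured
# content parts, instead of hand-inserting an <image> placeholder into the prompt
# string; the placeholder token differs between model families, so this route is
# less fragile. `build_inputs` is a hypothetical helper shown for illustration only.
def build_inputs(user_text, pil_image=None):
    # Each message's content is a list of typed parts; image parts sit alongside text.
    content = [{"type": "text", "text": user_text}]
    if pil_image is not None:
        content = [{"type": "image", "image": pil_image}] + content
    messages = [{"role": "user", "content": content}]
    # apply_chat_template on the processor (not the tokenizer) tokenizes the text
    # and preprocesses the images in a single call.
    return processor.apply_chat_template(
        messages,
        add_generation_prompt=True,
        tokenize=True,
        return_dict=True,
        return_tensors="pt",
    ).to(model.device, dtype)
# The result feeds straight into model.generate(**build_inputs("...", image), ...).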
# Create the Gradio interface, using Blocks for more control.
with gr.Blocks(theme=gr.themes.Soft(), css="footer {display: none !important}") as demo:
    gr.Markdown(
        """
        # AI Symptom Checker powered by MedGemma
        Describe your symptoms in the text box below. You can also upload an image (e.g., a skin rash).
        The AI assistant will ask clarifying questions before suggesting possible conditions and an action plan.
        """
    )

    # Chatbot component to display the conversation
    chatbot = gr.Chatbot(label="Conversation", height=500, avatar_images=("user.png", "bot.png"))  # Added avatars for fun

    # State to store the conversation history
    chat_history = gr.State([])  # <-- FIX 3: This state will now be used correctly.

    with gr.Row():
        # Image input
        image_box = gr.Image(type="pil", label="Upload Image of Symptom (Optional)")

    with gr.Row():
        # Text input
        text_box = gr.Textbox(
            label="Describe your symptoms...",
            placeholder="e.g., I have a rash on my arm. It's red and itchy.",
            scale=4,  # Make the textbox wider
        )
        # Submit button
        submit_btn = gr.Button("Send", variant="primary", scale=1)

    # Function to clear all inputs
    def clear_all():
        return [], [], None, ""  # <-- FIX 3: Correctly clear the state and the chatbot.

    # Clear button
    clear_btn = gr.Button("Start New Conversation")

    # <-- FIX 3: The outputs list now correctly targets the state.
    clear_btn.click(clear_all, outputs=[chatbot, chat_history, image_box, text_box], queue=False)

    # Define what happens when the user clicks Send.
    submit_btn.click(
        fn=symptom_checker_chat,
        # <-- FIX 3 & 4: Corrected inputs and outputs.
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box],
    )

    # Define what happens when the user presses Enter in the textbox.
    text_box.submit(
        fn=symptom_checker_chat,
        # <-- FIX 3 & 4: Corrected inputs and outputs.
        inputs=[text_box, chat_history, image_box],
        outputs=[chatbot, chat_history, image_box, text_box],
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch(debug=True)  # Debug mode for more detailed logs
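# Assumed dependencies for the Space (a sketch of requirements.txt; version pins
# are not taken from the original and may need adjusting):
#   gradio
#   torch
#   transformers
#   accelerate   # needed for device_map="auto"
#   Pillow
#   spaces       # ZeroGPU support package on Hugging Face Spaces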