import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces

# Get the Hugging Face token from the environment variables
hf_token = os.environ.get("HF_TOKEN")
# Debug prints for HF_TOKEN; uncomment if troubleshooting token issues
# print(f"HF_TOKEN set: {bool(hf_token)}")
# if hf_token:
#     print(f"HF_TOKEN length: {len(hf_token)}")
# else:
#     print("HF_TOKEN is not set!")

model_id = "google/medgemma-4b-it"

if torch.cuda.is_available():
    dtype = torch.bfloat16
    # print("CUDA is available. Using bfloat16.")
else:
    dtype = torch.float32
    # print("CUDA not available. Using float32 (model will load on CPU if device_map='auto').")

model_loaded = False
try:
    # print(f"Attempting to load model: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
    # print("Model loaded successfully!")
    # print(f"Model device: {model.device}")
except Exception as e:
    print(f"CRITICAL ERROR loading model: {e}")
    model_loaded = False


@spaces.GPU
def get_clinical_code(clinical_text):
    # print(f"\n--- Entering get_clinical_code for: '{clinical_text[:50]}...' ---")
    if not model_loaded:
        # print("Model not loaded, returning error.")
        return "Error: The model could not be loaded. Please check the logs."
    if not clinical_text:
        # print("Empty clinical text provided.")
        return "Please enter some clinical text."

    # Construct the full prompt exactly as it will be fed to the model.
    # Note the specific newline characters and spacing.
    user_prompt_section = f"""user
You are a clinical assistant. You are provided with patient medical records.
Analyze the clinical condition and suggest one or more actions missing in the document.

Clinical Text:
"{clinical_text}"

Provide the Diagnosis.
"""
    model_start_sequence = "model"

    # This is the complete prompt string that will be tokenized
    full_input_prompt = f"{user_prompt_section}\n{model_start_sequence}\n"
    # print(f"Full input prompt length (chars): {len(full_input_prompt)}")

    input_ids = tokenizer(full_input_prompt, return_tensors="pt").to(model.device)
    # print(f"Input_ids shape: {input_ids.input_ids.shape}")
    # print(f"Number of input tokens: {input_ids.input_ids.shape[1]}")

    # Generate the output from the model
    outputs = model.generate(
        **input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id  # Important for handling padding in generation
    )
    # print(f"Outputs shape: {outputs.shape}")
    # print(f"Total tokens generated (including prompt): {outputs.shape[1]}")

    # Decode the entire generated output first
    full_decoded_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
    # print(f"Full decoded response: '{full_decoded_response}'")

    # Find the exact start of the model's intended response.
    # We look for the model's turn marker in the *decoded* string.
    # This is more robust against the model regenerating part of the prompt.
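    # (Hedged alternative, not used by this app: the completion could instead be
    #  isolated by token position rather than string matching, e.g.
    #  new_tokens = outputs[0][input_ids.input_ids.shape[1]:] followed by
    #  tokenizer.decode(new_tokens, skip_special_tokens=True). The string-marker
    #  approach below is kept as written.)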
    response_start_marker = f"{model_start_sequence}\n"  # Expected start of the model's turn in the decoded string
    start_index = full_decoded_response.find(response_start_marker)

    if start_index != -1:
        # If the marker is found, slice everything after it
        clean_response = full_decoded_response[start_index + len(response_start_marker):].strip()
    else:
        # Fallback: if the marker isn't found (highly unlikely for Gemma models),
        # try to remove the original input prompt from the decoded output instead.
        # This branch is only a safety net.
        clean_response = full_decoded_response.replace(user_prompt_section, "").replace(model_start_sequence, "").strip()
        print("Warning: Model start marker not found in decoded response. Falling back to alternative cleaning.")

    # print(f"Final response (after cleaning): '{clean_response}'")
    # print("--- Exiting get_clinical_code ---")
    return clean_response


# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )
    with gr.Row():
        input_text = gr.Textbox(
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
            lines=10
        )
        output_text = gr.Textbox(
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
            lines=5
        )
    submit_button = gr.Button("Get Clinical Code", variant="primary")
    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text
    )
    gr.Examples(
        examples=[
            ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
            ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
            ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."]
        ],
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code
    )

if __name__ == "__main__":
    demo.launch()