import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces

# Get the Hugging Face token from the environment variables
hf_token = os.environ.get("HF_TOKEN")

# Debugging prints for HF_TOKEN, kept commented out in case token issues resurface
# print(f"HF_TOKEN set: {bool(hf_token)}")
# if hf_token:
#     print(f"HF_TOKEN length: {len(hf_token)}")
# else:
#     print("HF_TOKEN is not set!")

model_id = "google/medgemma-4b-it"

if torch.cuda.is_available():
    dtype = torch.bfloat16
    # print("CUDA is available. Using bfloat16.")
else:
    dtype = torch.float32
    # print("CUDA not available. Using float32 (model will load on CPU if device_map='auto').")

model_loaded = False
try:
    # print(f"Attempting to load model: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
    # print("Model loaded successfully!")
    # print(f"Model device: {model.device}")
except Exception as e:
    print(f"CRITICAL ERROR loading model: {e}")
    model_loaded = False

# `spaces.GPU` requests ZeroGPU hardware for the duration of each call
# (this is why `spaces` is imported); it is a no-op on non-ZeroGPU Spaces.
@spaces.GPU
def get_clinical_code(clinical_text):
    # print(f"\n--- Entering get_clinical_code for: '{clinical_text[:50]}...' ---")
    if not model_loaded:
        # print("Model not loaded, returning error.")
        return "Error: The model could not be loaded. Please check the logs."

    if not clinical_text:
        # print("Empty clinical text provided.")
        return "Please enter some clinical text."

    # Construct the full prompt exactly as it will be fed to the model.
    # Note the specific newline characters and spacing.
    user_prompt_section = f"""<start_of_turn>user
You are an expert medical coder. Your task is to analyze the following clinical text and provide the diagnosis in standard medical terminology. Specify each diagnosis as a main diagnosis or an underlying chronic condition.
Clinical Text: "{clinical_text}"
Provide the ICD-10 code and a brief description.
<end_of_turn>"""
    model_start_sequence = "<start_of_turn>model"

    # This is the complete prompt string that will be tokenized
    full_input_prompt = f"{user_prompt_section}\n{model_start_sequence}\n"
# print(f"Full input prompt length (chars): {len(full_input_prompt)}") | |
input_ids = tokenizer(full_input_prompt, return_tensors="pt").to(model.device) | |
# print(f"Input_ids shape: {input_ids.input_ids.shape}") | |
# print(f"Number of input tokens: {input_ids.input_ids.shape[1]}") | |
# Generate the output from the model | |
outputs = model.generate( | |
**input_ids, | |
max_new_tokens=256, | |
do_sample=True, | |
temperature=0.7, | |
pad_token_id=tokenizer.eos_token_id # Important for handling padding in generation | |
) | |
# print(f"Outputs shape: {outputs.shape}") | |
# print(f"Total tokens generated (including prompt): {outputs.shape[1]}") | |

    # Decode the entire generated output, keeping special tokens so the Gemma
    # turn markers are still present and can be used for slicing.
    full_decoded_response = tokenizer.decode(outputs[0], skip_special_tokens=False)
    # print(f"Full decoded response: '{full_decoded_response}'")

    # Find the exact start of the model's intended response by locating the
    # model's turn marker in the decoded string. This is more robust than
    # character-offset slicing if the model regenerates part of the prompt.
    response_start_marker = f"{model_start_sequence}\n"
    start_index = full_decoded_response.find(response_start_marker)

    if start_index != -1:
        # Marker found: slice everything after it and drop trailing special tokens
        clean_response = full_decoded_response[start_index + len(response_start_marker):]
        clean_response = clean_response.replace("<end_of_turn>", "").replace(tokenizer.eos_token, "").strip()
    else:
        # Safety net: if the marker isn't found, decode only the newly generated
        # tokens (everything after the prompt).
        print("Warning: Model start marker not found in decoded response. Falling back to slicing by prompt length.")
        prompt_length = input_ids.input_ids.shape[1]
        clean_response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True).strip()

    # print(f"Final response (after cleaning): '{clean_response}'")
    # print("--- Exiting get_clinical_code ---")
    return clean_response

# Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.

        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )

    with gr.Row():
        input_text = gr.Textbox(
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
            lines=10
        )
        output_text = gr.Textbox(
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
            lines=5
        )

    submit_button = gr.Button("Get Clinical Code", variant="primary")

    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text
    )

    gr.Examples(
        examples=[
            ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
            ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
            ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."]
        ],
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code
    )

if __name__ == "__main__":
    demo.launch()