import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces
# Get the Hugging Face token from the environment variables
hf_token = os.environ.get("HF_TOKEN")

# Debugging prints for HF_TOKEN, commented out for brevity; re-enable if debugging token issues:
# print(f"HF_TOKEN set: {bool(hf_token)}")
# if hf_token:
#     print(f"HF_TOKEN length: {len(hf_token)}")
# else:
#     print("HF_TOKEN is not set!")
model_id = "google/medgemma-4b-it"

if torch.cuda.is_available():
    dtype = torch.bfloat16
    # print("CUDA is available. Using bfloat16.")
else:
    dtype = torch.float32
    # print("CUDA not available. Using float32 (model will load on CPU if device_map='auto').")
model_loaded = False
try:
    # print(f"Attempting to load model: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
    # print("Model loaded successfully!")
    # print(f"Model device: {model.device}")
except Exception as e:
    print(f"CRITICAL ERROR loading model: {e}")
    model_loaded = False
# On a ZeroGPU Space, spaces.GPU requests GPU hardware for the duration of
# each call (and is the reason for the spaces import above).
@spaces.GPU
def get_clinical_code(clinical_text):
    # print(f"\n--- Entering get_clinical_code for: '{clinical_text[:50]}...' ---")
    if not model_loaded:
        # print("Model not loaded, returning error.")
        return "Error: The model could not be loaded. Please check the logs."
    if not clinical_text:
        # print("Empty clinical text provided.")
        return "Please enter some clinical text."
    # Construct the full prompt exactly as it will be fed to the model.
    # Note the specific newline characters and spacing.
    user_prompt_section = f"""<start_of_turn>user
You are an expert medical coder. Your task is to analyze the following clinical text and determine the most appropriate ICD-10 code. Provide only the ICD-10 code and a brief description.
Clinical Text: "{clinical_text}"
Provide the ICD-10 code and a brief description.
<end_of_turn>"""

    model_start_sequence = "<start_of_turn>model"
    # This is the complete prompt string that will be tokenized.
    full_input_prompt = f"{user_prompt_section}\n{model_start_sequence}\n"
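    # Alternative sketch (not what this demo uses): instruction-tuned
    # Gemma-family tokenizers ship a chat template, so the hand-written turn
    # markers above could instead be produced with apply_chat_template, e.g.:
    #   messages = [{"role": "user", "content": instruction_text}]
    #   full_input_prompt = tokenizer.apply_chat_template(
    #       messages, tokenize=False, add_generation_prompt=True
    #   )
    # where instruction_text (hypothetical name) is the instruction body
    # without the <start_of_turn>/<end_of_turn> markers.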
# print(f"Full input prompt length (chars): {len(full_input_prompt)}") | |
input_ids = tokenizer(full_input_prompt, return_tensors="pt").to(model.device) | |
# print(f"Input_ids shape: {input_ids.input_ids.shape}") | |
# print(f"Number of input tokens: {input_ids.input_ids.shape[1]}") | |
    # Generate the output from the model.
    outputs = model.generate(
        **input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # important for handling padding in generation
    )
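    # Note: do_sample=True with temperature=0.7 makes the suggested code vary
    # across runs. For reproducible output on a coding task, greedy decoding
    # is a common alternative (an option, not what this demo uses):
    #   outputs = model.generate(**input_ids, max_new_tokens=256,
    #                            do_sample=False,
    #                            pad_token_id=tokenizer.eos_token_id)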
# print(f"Outputs shape: {outputs.shape}") | |
# print(f"Total tokens generated (including prompt): {outputs.shape[1]}") | |
    # Decode only the newly generated tokens. The prompt always occupies the
    # first tokens of the output, so slicing by prompt length cleanly separates
    # the model's answer and is robust even if the model echoes part of the
    # prompt. (Searching the decoded string for the "<start_of_turn>model"
    # marker would fail here, because skip_special_tokens=True strips the turn
    # markers during decoding.)
    prompt_token_count = input_ids.input_ids.shape[1]
    generated_tokens = outputs[0][prompt_token_count:]
    clean_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    # print(f"Final response (after cleaning): '{clean_response}'")
    # print("--- Exiting get_clinical_code ---")
    return clean_response
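# Quick local smoke test (hypothetical example input; assumes the model
# loaded successfully):
#   print(get_clinical_code("Patient presents with acute appendicitis."))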
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.

        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )
    with gr.Row():
        input_text = gr.Textbox(
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
            lines=10,
        )
        output_text = gr.Textbox(
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
            lines=5,
        )

    submit_button = gr.Button("Get Clinical Code", variant="primary")
    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text,
    )

    gr.Examples(
        examples=[
            ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
            ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
            ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."],
        ],
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code,
    )
if __name__ == "__main__":
    demo.launch()
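# A minimal requirements.txt sketch for this Space (package names only;
# version pins omitted). accelerate is listed because loading with
# device_map="auto" requires it:
#
#   gradio
#   torch
#   transformers
#   accelerate
#   spaces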