import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces
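# Note: device_map="auto" in from_pretrained below relies on the `accelerate`
# package being installed (add it to requirements.txt if it is missing).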
# Get the Hugging Face token from the environment variables
hf_token = os.environ.get("HF_TOKEN")
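# google/medgemma-4b-it is a gated checkpoint, so HF_TOKEN (e.g. a Space secret)
# must belong to an account that has accepted the model's usage terms.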
# Debug prints for HF_TOKEN, kept commented out in case token issues resurface:
# print(f"HF_TOKEN set: {bool(hf_token)}")
# if hf_token:
#     print(f"HF_TOKEN length: {len(hf_token)}")
# else:
#     print("HF_TOKEN is not set!")
model_id = "google/medgemma-4b-it"
if torch.cuda.is_available():
    dtype = torch.bfloat16
    # print("CUDA is available. Using bfloat16.")
else:
    dtype = torch.float32
    # print("CUDA not available. Using float32 (model will load on CPU if device_map='auto').")
model_loaded = False
try:
    # print(f"Attempting to load model: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
    # print("Model loaded successfully!")
    # print(f"Model device: {model.device}")
except Exception as e:
    print(f"CRITICAL ERROR loading model: {e}")
    model_loaded = False
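# @spaces.GPU requests ZeroGPU hardware for the duration of each call when the
# app runs on a Hugging Face ZeroGPU Space; elsewhere it should act as a no-op.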
@spaces.GPU
def get_clinical_code(clinical_text):
# print(f"\n--- Entering get_clinical_code for: '{clinical_text[:50]}...' ---")
if not model_loaded:
# print("Model not loaded, returning error.")
return "Error: The model could not be loaded. Please check the logs."
if not clinical_text:
# print("Empty clinical text provided.")
return "Please enter some clinical text."
    # Construct the full prompt exactly as it will be fed to the model.
    # Note the specific newline characters and spacing.
    user_prompt_section = f"""<start_of_turn>user
You are an expert medical coder. Your task is to analyze the following clinical text and determine the most appropriate ICD-10 code. Provide only the ICD-10 code and a brief description.
Clinical Text: "{clinical_text}"
Provide the ICD-10 code and a brief description.
<end_of_turn>"""
    model_start_sequence = "<start_of_turn>model"
    # This is the complete prompt string that will be tokenized.
    full_input_prompt = f"{user_prompt_section}\n{model_start_sequence}\n"
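    # Alternative: Gemma-family tokenizers ship a chat template that emits these
    # same <start_of_turn> markers, so the prompt above could also be built via
    # apply_chat_template. A sketch, assuming the checkpoint's template matches
    # this format (user_instructions is a hypothetical variable holding the
    # instruction text without the turn markers):
    # messages = [{"role": "user", "content": user_instructions}]
    # full_input_prompt = tokenizer.apply_chat_template(
    #     messages, tokenize=False, add_generation_prompt=True
    # )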
# print(f"Full input prompt length (chars): {len(full_input_prompt)}")
input_ids = tokenizer(full_input_prompt, return_tensors="pt").to(model.device)
# print(f"Input_ids shape: {input_ids.input_ids.shape}")
# print(f"Number of input tokens: {input_ids.input_ids.shape[1]}")
    # Generate the output from the model.
    outputs = model.generate(
        **input_ids,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id,  # important for handling padding in generation
    )
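    # Sampling at temperature 0.7 trades determinism for variety; for a coding
    # task where reproducible suggestions matter, greedy decoding
    # (do_sample=False, no temperature) may be the safer default.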
# print(f"Outputs shape: {outputs.shape}")
# print(f"Total tokens generated (including prompt): {outputs.shape[1]}")
    # Decode only the newly generated tokens. Slicing at the prompt length is
    # more reliable than searching the decoded string for the model's turn
    # marker: <start_of_turn> is a special token, so skip_special_tokens=True
    # strips it from the decoded text and a string search would never match.
    prompt_length = input_ids.input_ids.shape[1]
    generated_tokens = outputs[0][prompt_length:]
    clean_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()
    # print(f"Final response (after cleaning): '{clean_response}'")
    # print("--- Exiting get_clinical_code ---")
    return clean_response
# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )
    with gr.Row():
        input_text = gr.Textbox(
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
            lines=10
        )
        output_text = gr.Textbox(
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
            lines=5
        )
    submit_button = gr.Button("Get Clinical Code", variant="primary")
    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text
    )
    gr.Examples(
        examples=[
            ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
            ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
            ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."]
        ],
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code
    )
if __name__ == "__main__":
    demo.launch()