import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces
# Get the Hugging Face token from the environment variables
hf_token = os.environ.get("HF_TOKEN")
# Removed debugging prints for HF_TOKEN for brevity, but keep them if still debugging token issues
# print(f"HF_TOKEN set: {bool(hf_token)}")
# if hf_token:
#     print(f"HF_TOKEN length: {len(hf_token)}")
# else:
#     print("HF_TOKEN is not set!")
model_id = "google/medgemma-4b-it"
if torch.cuda.is_available():
    dtype = torch.bfloat16
    # print("CUDA is available. Using bfloat16.")
else:
    dtype = torch.float32
    # print("CUDA not available. Using float32 (model will load on CPU if device_map='auto').")
model_loaded = False
try:
# print(f"Attempting to load model: {model_id}")
tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
model = AutoModelForCausalLM.from_pretrained(
model_id,
token=hf_token,
torch_dtype=dtype,
device_map="auto",
)
model_loaded = True
# print("Model loaded successfully!")
# print(f"Model device: {model.device}")
except Exception as e:
print(f"CRITICAL ERROR loading model: {e}")
model_loaded = False
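
# @spaces.GPU (from the Hugging Face `spaces` package) requests a GPU for each call on
# ZeroGPU Spaces; on hardware that is not ZeroGPU the decorator has no effect.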
@spaces.GPU
def get_clinical_code(clinical_text):
# print(f"\n--- Entering get_clinical_code for: '{clinical_text[:50]}...' ---")
if not model_loaded:
# print("Model not loaded, returning error.")
return "Error: The model could not be loaded. Please check the logs."
if not clinical_text:
# print("Empty clinical text provided.")
return "Please enter some clinical text."
# Construct the full prompt exactly as it will be fed to the model
# Note the specific newline characters and spacing
user_prompt_section = f"""<start_of_turn>user
You are an expert medical coder. Your task is to analyze the following clinical text and determine the most appropriate ICD-10 code. Provide only the ICD-10 code and a brief description.
Clinical Text: "{clinical_text}"
Provide the ICD-10 code and a brief description.
<end_of_turn>"""
model_start_sequence = "<start_of_turn>model"
# This is the complete prompt string that will be tokenized
full_input_prompt = f"{user_prompt_section}\n{model_start_sequence}\n"
# print(f"Full input prompt length (chars): {len(full_input_prompt)}")
input_ids = tokenizer(full_input_prompt, return_tensors="pt").to(model.device)
# print(f"Input_ids shape: {input_ids.input_ids.shape}")
# print(f"Number of input tokens: {input_ids.input_ids.shape[1]}")
# Generate the output from the model
outputs = model.generate(
**input_ids,
max_new_tokens=256,
do_sample=True,
temperature=0.7,
pad_token_id=tokenizer.eos_token_id # Important for handling padding in generation
)
# print(f"Outputs shape: {outputs.shape}")
# print(f"Total tokens generated (including prompt): {outputs.shape[1]}")
# Decode the entire generated output first
full_decoded_response = tokenizer.decode(outputs[0], skip_special_tokens=True)
# print(f"Full decoded response: '{full_decoded_response}'")
# Find the exact start of the model's intended response
# We look for the model's turn marker in the *decoded* string.
# This is more robust against the model regenerating part of the prompt.
response_start_marker = f"{model_start_sequence}\n" # This is the expected start of the model's turn in the decoded string
start_index = full_decoded_response.find(response_start_marker)
if start_index != -1:
# If the marker is found, slice everything after it
clean_response = full_decoded_response[start_index + len(response_start_marker):].strip()
else:
# Fallback: If the marker isn't found (highly unlikely for Gemma models),
# try the previous slicing approach or return the full decoded output with a warning.
# For robustness, we'll try to remove the original input prompt if the marker is missing.
# This part is a safety net.
clean_response = full_decoded_response.replace(user_prompt_section, "").replace(model_start_sequence, "").strip()
print("Warning: Model start marker not found in decoded response. Falling back to alternative cleaning.")
# print(f"Final response (after cleaning): '{clean_response}'")
# print("--- Exiting get_clinical_code ---")
return clean_response

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )
    with gr.Row():
        input_text = gr.Textbox(
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
            lines=10
        )
        output_text = gr.Textbox(
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
            lines=5
        )
    submit_button = gr.Button("Get Clinical Code", variant="primary")
    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text
    )
    gr.Examples(
        examples=[
            ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
            ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
            ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."]
        ],
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code
    )
if __name__ == "__main__":
    demo.launch()