hponepyae's picture
Update app.py
d513773 verified
raw
history blame
4.87 kB
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces # <--- Add this import!
# Hugging Face access token, read from the HF_TOKEN environment variable
# (None when the variable is not set).
hf_token = os.environ.get("HF_TOKEN")

# MedGemma: a 4-billion-parameter, instruction-tuned model specialised for
# the medical domain.
model_id = "google/medgemma-4b-it"

# bfloat16 when CUDA is available, full float32 precision otherwise.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load tokenizer and model once at import time. Any failure is logged and
# recorded in `model_loaded` rather than raised, so the rest of the module
# still loads and can report the problem to the user.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model_loaded = False
else:
    model_loaded = True
# Core inference entry point: takes clinical text and returns a suggested code.
@spaces.GPU  # decorator from the `spaces` package; presumably attaches a GPU per call — confirm against the Spaces docs
def get_clinical_code(clinical_text):
    """Suggest an ICD-10 code for unstructured clinical text using MedGemma.

    Args:
        clinical_text: Free-form clinical note entered by the user. ``None``,
            empty and whitespace-only values are all treated as missing input.

    Returns:
        The model's decoded reply (the prompt asks for an ICD-10 code plus a
        brief description) with surrounding whitespace removed, or a
        human-readable message when the model failed to load or no usable
        text was supplied.
    """
    if not model_loaded:
        return "Error: The model could not be loaded. Please check the logs."

    # Normalise before validating so that a whitespace-only note is rejected
    # instead of being sent to the model as an empty clinical text.
    note = (clinical_text or "").strip()

    # The prompt below is assembled by hand from Gemma turn markers. The same
    # markers inside user text would be read as control tokens and could close
    # the user turn early, so remove them. Repeat until none remain, because
    # deleting one marker can splice a new one together.
    turn_markers = ("<start_of_turn>", "<end_of_turn>")
    while any(marker in note for marker in turn_markers):
        for marker in turn_markers:
            note = note.replace(marker, "")
    note = note.strip()

    if not note:
        return "Please enter some clinical text."

    # Prompt template steering the model towards clinical coding. ICD-10 is
    # requested because it is a common coding standard.
    prompt = f"""<start_of_turn>user
You are an expert medical coder. Your task is to analyze the following clinical text and determine the most appropriate ICD-10 code. Provide only the ICD-10 code and a brief description.
Clinical Text: "{note}"
Provide the ICD-10 code and a brief description.
<end_of_turn>
<start_of_turn>model
"""

    # Tokenize and move every input tensor to the device the model lives on.
    model_inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # Allow up to 256 newly generated tokens, which should be enough for a
    # code and a short description.
    outputs = model.generate(
        **model_inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,  # A lower temperature makes the output more deterministic
    )

    # `generate` returns prompt + continuation; drop the prompt tokens so only
    # the model's reply is decoded.
    prompt_length = model_inputs["input_ids"].shape[1]
    generated_tokens = outputs[0, prompt_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)

    # The prompt asks for just the code and description, so the decoded reply
    # is returned as-is apart from trimming whitespace.
    return response.strip()
# Build the Gradio UI: an intro blurb, an input/output pair of text boxes,
# a submit button, and a few clickable example notes.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Text shown at the top of the page.
    intro_markdown = """
# Clinical Code Generator with Google MedGemma
Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
*Disclaimer: This is a demonstration and not for professional medical use.*
"""
    gr.Markdown(intro_markdown)

    # NOTE(review): the two text boxes are assumed to sit side by side inside
    # the row, with the button below it — confirm against the deployed layout.
    with gr.Row():
        # Where the user types or pastes the clinical note.
        input_text = gr.Textbox(
            lines=10,
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
        )
        # Read-only box that receives the model's suggestion.
        output_text = gr.Textbox(
            lines=5,
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
        )

    submit_button = gr.Button("Get Clinical Code", variant="primary")

    # Clicking the button runs the coder on the input box and writes the
    # result into the output box.
    submit_button.click(get_clinical_code, inputs=input_text, outputs=output_text)

    # Sample notes; each row holds the value for the single input component.
    sample_notes = (
        "The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia.",
        "45-year-old male with a history of hypertension presents with chest pain radiating to the left arm.",
        "The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated.",
    )
    gr.Examples(
        examples=[[note] for note in sample_notes],
        fn=get_clinical_code,  # Examples invoke the same function, so they also use the GPU
        inputs=input_text,
        outputs=output_text,
    )

# Start the app only when this file is run as a script.
if __name__ == "__main__":
    demo.launch()