import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# Get the Hugging Face token from the environment variables
hf_token = os.environ.get("HF_TOKEN")
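# MedGemma is a gated model on the Hugging Face Hub, so a valid HF_TOKEN
# (e.g. set as a Space secret) is typically required to download it.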
# Initialize the tokenizer and model.
# MedGemma is a 4-billion-parameter instruction-tuned model
# specialized for the medical domain.
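# Model card: https://huggingface.co/google/medgemma-4b-it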
model_id = "google/medgemma-4b-it"
# Check for GPU availability and set the data type accordingly
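# bfloat16 uses half the memory of float32 on GPUs that support it;
# float32 is the safe fallback on CPU.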
if torch.cuda.is_available():
    dtype = torch.bfloat16
else:
    dtype = torch.float32
# Load the tokenizer and model from Hugging Face
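# Note: device_map="auto" requires the accelerate package to be installed.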
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
except Exception as e:
    print(f"Error loading model: {e}")
    model_loaded = False
# This is the core function that will take the clinical text and return a code
def get_clinical_code(clinical_text):
    """
    Generates a clinical code from unstructured clinical text using the MedGemma model.
    """
    if not model_loaded:
        return "Error: The model could not be loaded. Please check the logs."
    if not clinical_text:
        return "Please enter some clinical text."
    # This is our prompt template. It's designed to guide the model
    # to perform the specific task of clinical coding.
    # We are asking for an ICD-10 code, which is a common standard.
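    # The <start_of_turn>/<end_of_turn> markers follow the Gemma chat-turn
    # format, which MedGemma also uses.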
    prompt = f"""
<start_of_turn>user
You are an expert medical coder. Your task is to analyze the following clinical text and determine the most appropriate ICD-10 code. Provide only the ICD-10 code and a brief description.
Clinical Text: "{clinical_text}"
<end_of_turn>
<start_of_turn>model
"""
    # Prepare the input for the model and move it to the model's device
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    # Generate the output from the model.
    # Up to 256 new tokens is enough for a code and a short description.
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,  # moderate sampling temperature; lower values are more deterministic
    )
    # Decode only the newly generated tokens; slicing off the prompt is more
    # robust than searching the decoded text for the "<start_of_turn>model"
    # marker, which may be stripped when skip_special_tokens=True.
    prompt_length = inputs["input_ids"].shape[-1]
    response = tokenizer.decode(outputs[0][prompt_length:], skip_special_tokens=True)
    # Drop a trailing end-of-turn marker if the tokenizer left it in.
    clean_response = response.split("<end_of_turn>")[0].strip()
    return clean_response
# Create the Gradio Interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )
    with gr.Row():
        # Input Textbox
        input_text = gr.Textbox(
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
            lines=10
        )
        # Output Textbox
        output_text = gr.Textbox(
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
            lines=5
        )
    # Submit Button
    submit_button = gr.Button("Get Clinical Code", variant="primary")
    # Connect the button to the function
    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text
    )
    gr.Examples(
        examples=[
            ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
            ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
            ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."]
        ],
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code
    )
# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()