Spaces:
Sleeping
Sleeping
File size: 4,450 Bytes
83ff66a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
# Access token for the Hugging Face Hub, read from the environment
# (None when HF_TOKEN is not set).
hf_token = os.getenv("HF_TOKEN")

# Checkpoint to serve: MedGemma, a 4 billion parameter instruction-tuned
# model specialized for the medical domain.
model_id = "google/medgemma-4b-it"

# Use bfloat16 when a CUDA device is present, full float32 otherwise.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Fetch the tokenizer and the model weights. Any failure is reported on
# stdout and recorded in `model_loaded` so the request handler can return a
# friendly message instead of crashing the app at import time.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
except Exception as exc:
    print(f"Error loading model: {exc}")
    model_loaded = False
else:
    model_loaded = True
# This is the core function that will take the clinical text and return a code
def get_clinical_code(clinical_text):
    """Suggest an ICD-10 code for a piece of unstructured clinical text.

    Wraps ``clinical_text`` in a single-turn prompt, samples a completion from
    the module-level ``model`` and returns only the newly generated text.

    Args:
        clinical_text: Free-text clinical note entered by the user.

    Returns:
        The generated answer with surrounding whitespace removed, or a fixed
        human-readable message when the model failed to load or the input is
        empty / whitespace-only.
    """
    if not model_loaded:
        return "Error: The model could not be loaded. Please check the logs."
    # Treat None, "" and whitespace-only input alike: there is nothing to code,
    # so do not spend a generation call on it.
    if not clinical_text or not clinical_text.strip():
        return "Please enter some clinical text."

    # Prompt template guiding the model towards the clinical-coding task.
    # ICD-10 is requested because it is a common coding standard.
    prompt = (
        "\n"
        "<start_of_turn>user\n"
        "You are an expert medical coder. Your task is to analyze the following clinical text and determine the most appropriate ICD-10 code. Provide only the ICD-10 code and a brief description.\n"
        f'Clinical Text: "{clinical_text}"\n'
        "Provide the ICD-10 code and a brief description.\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

    # Tokenize and move the tensors to the device the model lives on.
    encoded = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = encoded["input_ids"].shape[-1]

    # 256 new tokens is ample for a code plus a short description.
    outputs = model.generate(
        **encoded,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,  # A lower temperature makes the output more deterministic
    )

    # The returned sequence starts with the prompt tokens, so drop them by
    # position. Searching the decoded text for "<start_of_turn>model" is not
    # reliable: with skip_special_tokens=True the marker may be stripped, in
    # which case str.find() returns -1 and the slice would begin at a fixed
    # offset inside the echoed prompt instead of at the model's answer.
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Static content for the page, kept apart from the layout code below.
INTRO_MARKDOWN = """
# Clinical Code Generator with Google MedGemma
Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
*Disclaimer: This is a demonstration and not for professional medical use.*
"""

# Each inner list is one example row; the single entry fills the input box.
EXAMPLE_NOTES = [
    ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
    ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
    ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."],
]

# Build the Gradio UI: heading, input/output boxes side by side, a submit
# button and a set of clickable examples.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(INTRO_MARKDOWN)

    with gr.Row():
        # Free-text note typed or pasted by the user.
        input_text = gr.Textbox(
            lines=10,
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
        )
        # Read-only box showing the model's suggestion.
        output_text = gr.Textbox(
            lines=5,
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
        )

    submit_button = gr.Button("Get Clinical Code", variant="primary")

    # Clicking the button runs the note through get_clinical_code.
    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text,
    )

    gr.Examples(
        examples=EXAMPLE_NOTES,
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text,
    )

# Start the web server only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()
|