Spaces:
Sleeping
Sleeping
import gradio as gr | |
import torch | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
import os | |
# Hugging Face access token, read from the environment (None when unset).
hf_token = os.environ.get("HF_TOKEN")

# MedGemma: a 4-billion-parameter, instruction-tuned model specialised for
# the medical domain.
model_id = "google/medgemma-4b-it"

# Use bfloat16 when a CUDA GPU is available, float32 otherwise.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Fetch the tokenizer and the model from Hugging Face. Any failure is printed
# and recorded in `model_loaded` instead of aborting the module import, so
# later code can check the flag before touching `tokenizer` / `model`.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
except Exception as load_error:
    print(f"Error loading model: {load_error}")
    model_loaded = False
else:
    model_loaded = True
# Core function: takes free-text clinical notes and returns a suggested code.
def get_clinical_code(clinical_text):
    """Suggest an ICD-10 code for unstructured clinical text using MedGemma.

    Args:
        clinical_text: Free-text clinical note entered by the user.

    Returns:
        The text generated by the model for this prompt (the prompt asks for
        an ICD-10 code plus a brief description), or a human-readable message
        when the model failed to load or the input is empty or
        whitespace-only.
    """
    if not model_loaded:
        return "Error: The model could not be loaded. Please check the logs."
    # Whitespace-only input carries no clinical content; treat it as empty
    # rather than spending a generation on it.
    if not clinical_text or not clinical_text.strip():
        return "Please enter some clinical text."

    # Prompt template using the <start_of_turn> / <end_of_turn> turn markers.
    # It steers the model towards the clinical-coding task and asks for an
    # ICD-10 code, which is a common standard.
    prompt = (
        "\n"
        "<start_of_turn>user\n"
        "You are an expert medical coder. Your task is to analyze the "
        "following clinical text and determine the most appropriate ICD-10 "
        "code. Provide only the ICD-10 code and a brief description.\n"
        f'Clinical Text: "{clinical_text}"\n'
        "Provide the ICD-10 code and a brief description.\n"
        "<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

    # Tokenize and move the tensors to the device the model lives on.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    # 256 new tokens should be sufficient for a code and a short description.
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,  # A lower temperature makes the output more deterministic
    )

    # The generated sequence starts with the prompt tokens. Drop them by
    # length before decoding: searching the decoded text for
    # "<start_of_turn>model" does not work, because skip_special_tokens=True
    # removes that marker, so find() returns -1 and the slice lands at an
    # arbitrary offset inside the echoed prompt.
    prompt_length = inputs["input_ids"].shape[-1]
    generated_ids = outputs[0][prompt_length:]
    return tokenizer.decode(generated_ids, skip_special_tokens=True).strip()
# Static content for the page, kept apart from the layout code below.
_header_markdown = (
    "\n"
    "# Clinical Code Generator with Google MedGemma\n"
    "Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.\n"
    "*Disclaimer: This is a demonstration and not for professional medical use.*\n"
)
_example_notes = [
    ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
    ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
    ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."],
]

# Assemble the Gradio interface: header, input/output boxes, button, examples.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(_header_markdown)

    # NOTE(review): the source's indentation was lost; the two textboxes are
    # assumed to share this row, with the button and examples below it.
    # Confirm against the deployed layout.
    with gr.Row():
        # Where the user types or pastes the clinical note.
        input_text = gr.Textbox(
            lines=10,
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
        )
        # Read-only box showing the model's suggestion.
        output_text = gr.Textbox(
            lines=5,
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
        )

    submit_button = gr.Button("Get Clinical Code", variant="primary")
    # Clicking the button runs the coding function on the input text.
    submit_button.click(fn=get_clinical_code, inputs=input_text, outputs=output_text)

    gr.Examples(
        examples=_example_notes,
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code,
    )

# Start the app only when this file is executed as a script.
if __name__ == "__main__":
    demo.launch()