File size: 4,450 Bytes
83ff66a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os

# Hugging Face access token, read from the environment (None if HF_TOKEN is unset).
hf_token = os.environ.get("HF_TOKEN")

# MedGemma: a 4-billion-parameter, instruction-tuned model specialised for the
# medical domain.
model_id = "google/medgemma-4b-it"

# Use bfloat16 when a CUDA GPU is available; otherwise fall back to float32.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load tokenizer and weights once, at import time. A failure is logged rather
# than raised so the UI can still start; get_clinical_code checks
# `model_loaded` and returns an error message instead of generating.
model_loaded = False
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=dtype,
        device_map="auto",
        token=hf_token,
    )
except Exception as exc:
    print(f"Error loading model: {exc}")
else:
    model_loaded = True

# This is the core function that will take the clinical text and return a code
def get_clinical_code(clinical_text):
    """
    Suggest an ICD-10 code for unstructured clinical text using MedGemma.

    Args:
        clinical_text: Free-text clinical note entered by the user.

    Returns:
        The model's generated answer (an ICD-10 code and a brief description),
        or a human-readable message if the model is unavailable or the input
        is empty.
    """
    if not model_loaded:
        return "Error: The model could not be loaded. Please check the logs."

    # Treat None, "" and whitespace-only input alike: there is nothing to code.
    if not clinical_text or not clinical_text.strip():
        return "Please enter some clinical text."

    # Gemma-style chat prompt: a user turn followed by an open model turn.
    # It is built line by line so that no source-code indentation leaks into
    # the text the model sees (an indented triple-quoted literal would embed
    # leading spaces before every turn marker).
    prompt = (
        "<start_of_turn>user\n"
        "You are an expert medical coder. Your task is to analyze the following "
        "clinical text and determine the most appropriate ICD-10 code. Provide "
        "only the ICD-10 code and a brief description.\n"
        "\n"
        f'Clinical Text: "{clinical_text.strip()}"\n'
        "\n"
        "Provide the ICD-10 code and a brief description.<end_of_turn>\n"
        "<start_of_turn>model\n"
    )

    # Tokenize and move the tensors to wherever the model was placed.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[-1]

    # 256 new tokens is ample for a code plus a short description.
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,  # below 1.0: sharper, less random sampling
    )

    # generate() returns prompt + continuation, so keep only the continuation.
    # Searching the decoded text for "<start_of_turn>model" does not work:
    # skip_special_tokens=True removes that marker, find() returns -1, and the
    # resulting slice would return most of the echoed prompt.
    generated_tokens = outputs[0][prompt_length:]
    return tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

# ---------------------------------------------------------------------------
# Gradio user interface
# ---------------------------------------------------------------------------
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    # Title, short usage note and disclaimer shown above the widgets.
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )

    # Free-text note and model suggestion, placed side by side in one row.
    with gr.Row():
        input_text = gr.Textbox(
            lines=10,
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
        )
        output_text = gr.Textbox(
            lines=5,
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,  # display-only; users cannot edit the result
        )

    # Clicking the button runs the note through get_clinical_code.
    submit_button = gr.Button("Get Clinical Code", variant="primary")
    submit_button.click(get_clinical_code, inputs=input_text, outputs=output_text)

    # Sample notes users can click to populate the input box.
    gr.Examples(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text,
        examples=[
            ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
            ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
            ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."],
        ],
    )

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()