File size: 4,868 Bytes
83ff66a
 
 
 
8ec5a02
83ff66a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ec5a02
83ff66a
 
 
 
 
 
 
 
 
 
 
 
d513773
 
 
 
 
 
 
83ff66a
8ec5a02
 
83ff66a
 
 
 
 
 
 
 
 
 
 
8ec5a02
d513773
 
 
 
8ec5a02
d513773
 
 
 
 
83ff66a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8ec5a02
83ff66a
 
 
 
8ec5a02
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces # <--- Add this import!

# Hugging Face access token from the HF_TOKEN environment variable (None when
# unset); it is forwarded to both from_pretrained calls below.
hf_token = os.environ.get("HF_TOKEN")

# MedGemma: a 4-billion-parameter, instruction-tuned model specialised for the
# medical domain.
model_id = "google/medgemma-4b-it"

# Use bfloat16 when a CUDA GPU is visible (half the bytes of float32);
# otherwise fall back to full float32 precision.
dtype = torch.bfloat16 if torch.cuda.is_available() else torch.float32

# Load tokenizer and weights once, at import time. A failure is reported but
# does not crash the app: `model_loaded` records the outcome so the request
# handler can return an error message instead.
try:
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        device_map="auto",
        torch_dtype=dtype,
        token=hf_token,
    )
except Exception as exc:
    # Deliberate best-effort: log the reason and keep the UI alive.
    print(f"Error loading model: {exc}")
    model_loaded = False
else:
    model_loaded = True

# Core request handler: turns free-text clinical notes into a suggested ICD-10 code.
@spaces.GPU  # Hugging Face `spaces` GPU decorator; each call runs with a GPU attached.
def get_clinical_code(clinical_text):
    """
    Suggest an ICD-10 code for unstructured clinical text using MedGemma.

    Args:
        clinical_text: Free-text clinical note from the UI textbox. May be
            empty, whitespace-only, or None.

    Returns:
        The model's answer (an ICD-10 code plus a brief description) with
        surrounding whitespace removed, or a human-readable message when the
        model failed to load at startup or no usable text was supplied.
    """
    if not model_loaded:
        return "Error: The model could not be loaded. Please check the logs."
    # Whitespace-only input is as empty as "" -- don't spend GPU time asking
    # the model to code a note that contains nothing.
    if not clinical_text or not clinical_text.strip():
        return "Please enter some clinical text."
    clinical_text = clinical_text.strip()

    # Gemma-style chat turn asking for an ICD-10 code and short description.
    prompt = f"""<start_of_turn>user
You are an expert medical coder. Your task is to analyze the following clinical text and determine the most appropriate ICD-10 code. Provide only the ICD-10 code and a brief description.
Clinical Text: "{clinical_text}"
Provide the ICD-10 code and a brief description.
<end_of_turn>
<start_of_turn>model
"""
    # Tokenize and move every tensor to the device reported by the model
    # (model.device), so generate() receives inputs where the weights live.
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
    prompt_length = inputs["input_ids"].shape[1]

    # 256 new tokens is ample for a code plus a one-line description.
    # NOTE: sampling is enabled, so repeated calls on the same note may return
    # different codes; temperature 0.7 only sharpens the distribution, it does
    # not make the output deterministic.
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
    )

    # The returned sequence is prompt + continuation; decode only the
    # continuation so the prompt is never echoed back to the user.
    generated_tokens = outputs[0, prompt_length:]
    response = tokenizer.decode(generated_tokens, skip_special_tokens=True)
    return response.strip()

# Sample clinical notes offered as examples beneath the form.
EXAMPLE_NOTES = [
    ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
    ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
    ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."],
]

# Gradio user interface: header, note/answer textboxes, submit button, examples.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )

    # One horizontal row: the clinical note first, the model's answer second.
    with gr.Row():
        input_text = gr.Textbox(
            lines=10,
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
        )
        # Read-only: users see the suggestion but cannot edit it.
        output_text = gr.Textbox(
            lines=5,
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
        )

    submit_button = gr.Button("Get Clinical Code", variant="primary")
    submit_button.click(
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code,
    )

    # fn/outputs are supplied, so Gradio may invoke get_clinical_code for these
    # examples too (e.g. when caching them), which consumes GPU time.
    gr.Examples(
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code,
        examples=EXAMPLE_NOTES,
    )

# Start the web server only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()