import gradio as gr
import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import os
import spaces

# Get the Hugging Face token from the environment variables
hf_token = os.environ.get("HF_TOKEN")
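# MedGemma is a gated checkpoint on the Hub, so this token must belong to an
# account that has accepted the model's license terms on its model page.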
# Optional HF_TOKEN debug prints; uncomment when troubleshooting token access:
# print(f"HF_TOKEN set: {bool(hf_token)}")
# if hf_token:
#     print(f"HF_TOKEN length: {len(hf_token)}")
# else:
#     print("HF_TOKEN is not set!")

model_id = "google/medgemma-4b-it"

if torch.cuda.is_available():
    dtype = torch.bfloat16
    # print("CUDA is available. Using bfloat16.")
else:
    dtype = torch.float32
    # print("CUDA not available. Using float32 (model will load on CPU if device_map='auto').")

model_loaded = False
try:
    # print(f"Attempting to load model: {model_id}")
    tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        token=hf_token,
        torch_dtype=dtype,
        device_map="auto",
    )
    model_loaded = True
    # print("Model loaded successfully!")
    # print(f"Model device: {model.device}")
except Exception as e:
    print(f"CRITICAL ERROR loading model: {e}")
    model_loaded = False
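
# A possible variant for memory-constrained GPUs: 4-bit quantization via the
# optional bitsandbytes package (a sketch, not used by this demo):
#
# from transformers import BitsAndBytesConfig
# model = AutoModelForCausalLM.from_pretrained(
#     model_id,
#     token=hf_token,
#     quantization_config=BitsAndBytesConfig(load_in_4bit=True),
#     device_map="auto",
# )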

@spaces.GPU
def get_clinical_code(clinical_text):
    # print(f"\n--- Entering get_clinical_code for: '{clinical_text[:50]}...' ---")
    if not model_loaded:
        # print("Model not loaded, returning error.")
        return "Error: The model could not be loaded. Please check the logs."
    if not clinical_text:
        # print("Empty clinical text provided.")
        return "Please enter some clinical text."

    # Construct the full prompt exactly as it will be fed to the model
    # Note the specific newline characters and spacing
    user_prompt_section = f"""<start_of_turn>user
You are an expert medical coder. Your task is to analyze the following clinical text and determine the most appropriate ICD-10 code. Provide only the ICD-10 code and a brief description.
Clinical Text: "{clinical_text}"
<end_of_turn>"""

    model_start_sequence = "<start_of_turn>model"

    # This is the complete prompt string that will be tokenized
    full_input_prompt = f"{user_prompt_section}\n{model_start_sequence}\n"
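
    # Alternative sketch: Gemma instruction-tuned tokenizers ship a chat
    # template, so the same prompt could be produced without hand-written turn
    # markers (left commented out; the explicit string above stays the source
    # of truth):
    # messages = [{"role": "user", "content": "...instructions and clinical text..."}]
    # full_input_prompt = tokenizer.apply_chat_template(
    #     messages, tokenize=False, add_generation_prompt=True
    # )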

    # print(f"Full input prompt length (chars): {len(full_input_prompt)}")

    # Tokenize the prompt; the encoding carries both input_ids and attention_mask
    inputs = tokenizer(full_input_prompt, return_tensors="pt").to(model.device)
    # print(f"Input_ids shape: {inputs.input_ids.shape}")
    # print(f"Number of input tokens: {inputs.input_ids.shape[1]}")

    # Generate the output from the model
    outputs = model.generate(
        **inputs,
        max_new_tokens=256,
        do_sample=True,
        temperature=0.7,
        pad_token_id=tokenizer.eos_token_id  # fall back to EOS so generate() has an explicit pad token
    )
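    # Variant (not used here): deterministic decoding gives more reproducible
    # codes at the cost of less varied phrasing:
    # outputs = model.generate(**inputs, max_new_tokens=256, do_sample=False,
    #                          pad_token_id=tokenizer.eos_token_id)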
    # print(f"Outputs shape: {outputs.shape}")
    # print(f"Total tokens generated (including prompt): {outputs.shape[1]}")

    # The output tensor holds the prompt tokens followed by the newly generated
    # tokens, so slice off the prompt by token count. String matching on
    # "<start_of_turn>model" in the decoded text would fail here, because
    # skip_special_tokens=True strips Gemma's turn markers during decoding.
    prompt_length = inputs.input_ids.shape[1]
    generated_tokens = outputs[0][prompt_length:]
    clean_response = tokenizer.decode(generated_tokens, skip_special_tokens=True).strip()

    # print(f"Final response (after cleaning): '{clean_response}'")
    # print("--- Exiting get_clinical_code ---")
    return clean_response

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        """
        # Clinical Code Generator with Google MedGemma
        Enter a piece of unstructured clinical text below, and the app will suggest an ICD-10 clinical code.
        *Disclaimer: This is a demonstration and not for professional medical use.*
        """
    )
    with gr.Row():
        input_text = gr.Textbox(
            label="Unstructured Clinical Text",
            placeholder="e.g., Patient presents with a severe headache and photophobia...",
            lines=10
        )
        output_text = gr.Textbox(
            label="Suggested Clinical Code (ICD-10)",
            interactive=False,
            lines=5
        )
    submit_button = gr.Button("Get Clinical Code", variant="primary")
    submit_button.click(
        fn=get_clinical_code,
        inputs=input_text,
        outputs=output_text
    )
    gr.Examples(
        examples=[
            ["The patient complains of a persistent cough and fever for the past three days. Chest X-ray shows signs of pneumonia."],
            ["45-year-old male with a history of hypertension presents with chest pain radiating to the left arm."],
            ["The patient has a history of type 2 diabetes and is here for a routine check-up. Blood sugar levels are elevated."]
        ],
        inputs=input_text,
        outputs=output_text,
        fn=get_clinical_code
    )
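    # Note: gr.Examples also accepts cache_examples=True, which precomputes
    # these outputs at startup by calling fn on each example; it is left at the
    # default here because each call needs the GPU.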

if __name__ == "__main__":
    demo.launch()