# total-classifier/modeling.py
import cv2
import glob
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PreTrainedModel
from timm import create_model
from .configuration import TotalClassifierConfig
from .label2index import label2index
_PYDICOM_AVAILABLE = False
try:
from pydicom import dcmread
_PYDICOM_AVAILABLE = True
except ModuleNotFoundError:
pass
_PANDAS_AVAILABLE = False
try:
import pandas as pd
_PANDAS_AVAILABLE = True
except ModuleNotFoundError:
pass
class RNNHead(nn.Module):
def __init__(
self,
rnn_type: str,
rnn_num_layers: int,
rnn_dropout: float,
feature_dim: int,
linear_dropout: float,
num_classes: int,
):
super().__init__()
self.rnn = getattr(nn, rnn_type)(
input_size=feature_dim,
hidden_size=feature_dim // 2,
num_layers=rnn_num_layers,
dropout=rnn_dropout,
batch_first=True,
bidirectional=True,
)
self.dropout = nn.Dropout(linear_dropout)
self.linear = nn.Linear(feature_dim, num_classes)
    @staticmethod
    def convert_seq_and_mask_to_packed_sequence(
        seq: torch.Tensor, mask: torch.Tensor
    ) -> nn.utils.rnn.PackedSequence:
        assert seq.shape[0] == mask.shape[0]
        # per-sample valid lengths derived from the padding mask
        lengths = mask.sum(1)
        return nn.utils.rnn.pack_padded_sequence(
            seq, lengths.cpu().int(), batch_first=True, enforce_sorted=False
        )
def forward(
self, x: torch.Tensor, mask: torch.Tensor | None = None
) -> torch.Tensor:
skip = x
if mask is not None:
# convert to PackedSequence
L = x.shape[1]
x = self.convert_seq_and_mask_to_packed_sequence(x, mask)
x, _ = self.rnn(x)
if mask is not None:
# convert back to tensor
x = nn.utils.rnn.pad_packed_sequence(x, batch_first=True, total_length=L)[0]
x = x + skip
return self.linear(self.dropout(x))
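# Shape sketch for RNNHead (illustrative values, not taken from the repo config):
#   head = RNNHead(rnn_type="GRU", rnn_num_layers=1, rnn_dropout=0.0,
#                  feature_dim=512, linear_dropout=0.1, num_classes=13)
#   x = torch.randn(4, 32, 512)                 # (batch, num_slices, feature_dim)
#   mask = torch.ones(4, 32, dtype=torch.bool)  # all 32 slices are valid
#   logits = head(x, mask)                      # -> (4, 32, 13)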
class TotalClassifierModel(PreTrainedModel):
config_class = TotalClassifierConfig
def __init__(self, config):
super().__init__(config)
self.image_size = config.image_size
self.backbone = create_model(
model_name=config.backbone,
pretrained=False,
num_classes=0,
global_pool="",
features_only=True,
in_chans=config.in_chans,
)
self.cnn_dropout = nn.Dropout(p=config.cnn_dropout)
self.head = RNNHead(
rnn_type=config.rnn_type,
rnn_num_layers=config.rnn_num_layers,
rnn_dropout=config.rnn_dropout,
feature_dim=config.feature_dim,
linear_dropout=config.linear_dropout,
num_classes=config.num_classes,
)
self.label2index = label2index
self.index2label = {v: k for k, v in self.label2index.items()}
def forward(
self,
x: torch.Tensor,
mask: torch.Tensor | None = None,
return_logits: bool = False,
return_as_dict: bool = False,
return_as_list: bool = False,
return_as_df: bool = False,
threshold: float = 0.5, # only used for return_as_list=True
    ) -> torch.Tensor | list:
if return_as_df:
assert (
_PANDAS_AVAILABLE
), "`return_as_df=True` requires pandas to be installed"
        # x.shape = (b, n, c, h, w)
        b, n, c, h, w = x.shape
        # flatten study and slice dims: (b, n, c, h, w) -> (b * n, c, h, w)
        x = x.reshape(b * n, c, h, w)
        x = self.normalize(x)
        # extract feature maps from the CNN backbone
        features = self.backbone(x)
        # global average pool over the last (coarsest) feature map
        features = F.adaptive_avg_pool2d(features[-1], 1).flatten(1)
        features = self.cnn_dropout(features)
        # regroup slice features by study: (b * n, d) -> (b, n, d)
        features = features.reshape(b, n, -1)
logits = self.head(features, mask=mask)
if return_logits:
# return raw logits
return logits
probas = logits.sigmoid()
if return_as_dict or return_as_df:
# list of dictionaries
batch_list = []
for i in range(probas.shape[0]):
dict_for_batch = {}
probas_i = probas[i]
for each_class in range(probas_i.shape[1]):
dict_for_batch[self.index2label[each_class]] = probas_i[
:, each_class
]
if return_as_df:
batch_list.append(
pd.DataFrame(
{k: v.cpu().numpy() for k, v in dict_for_batch.items()}
)
)
else:
batch_list.append(dict_for_batch)
return batch_list
        if return_as_list:
            # returns a list (one entry per batch element / study) of lists
            # (one entry per slice) of lists of strings (the organs present
            # on that slice at `threshold`)
            batch_list = []
            # probas.shape = (batch_size, num_slices, num_classes)
            for i in range(probas.shape[0]):
                probas_i = probas[i]
                # probas_i.shape = (num_slices, num_classes)
                list_for_batch = []
                for each_slice in range(probas_i.shape[0]):
                    list_for_batch.append(
                        [
                            self.index2label[each_class]
                            for each_class in range(probas_i.shape[1])
                            if probas_i[each_slice, each_class] >= threshold
                        ]
                    )
                batch_list.append(list_for_batch)
            return batch_list
return probas
def normalize(self, x: torch.Tensor) -> torch.Tensor:
# [0, 255] -> [-1, 1]
mini, maxi = 0.0, 255.0
x = (x - mini) / (maxi - mini)
x = (x - 0.5) * 2.0
return x
    @staticmethod
    def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray:
        # apply CT windowing: clip to [WL - WW/2, WL + WW/2], rescale to [0, 255]
        lower, upper = WL - WW // 2, WL + WW // 2
        x = np.clip(x, lower, upper)
        x = (x - lower) / (upper - lower)
        return (x * 255.0).astype("uint8")
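    # Example with standard radiology window presets (illustrative; the model
    # does not assert any particular window here):
    #   soft_tissue = TotalClassifierModel.window(hu, WL=50, WW=400)  # clips HU to [-150, 250]
    #   lung = TotalClassifierModel.window(hu, WL=-600, WW=1500)      # clips HU to [-1350, 150]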
    @staticmethod
    def validate_windows_type(windows):
        assert isinstance(windows, (tuple, list))
        if isinstance(windows, tuple):
            assert len(windows) == 2
            assert all(isinstance(_, int) for _ in windows)
        else:
            assert all(isinstance(_, tuple) for _ in windows)
            assert all(len(_) == 2 for _ in windows)
            assert all(isinstance(__, int) for _ in windows for __ in _)
@staticmethod
def determine_dicom_orientation(ds) -> int:
iop = ds.ImageOrientationPatient
# Calculate the direction cosine for the normal vector of the plane
normal_vector = np.cross(iop[:3], iop[3:])
# Determine the plane based on the largest component of the normal vector
abs_normal = np.abs(normal_vector)
if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
return 0 # sagittal
elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
return 1 # coronal
else:
return 2 # axial
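    # Worked example: a typical axial slice has
    #   ImageOrientationPatient = [1, 0, 0, 0, 1, 0]
    #   normal = cross([1, 0, 0], [0, 1, 0]) = [0, 0, 1] -> z dominates -> returns 2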
def load_image_from_dicom(
self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
) -> np.ndarray:
        # `windows` may be a single (WINDOW_LEVEL, WINDOW_WIDTH) tuple, or a list
        # of such tuples to build a multi-channel image from > 1 window
if not _PYDICOM_AVAILABLE:
raise Exception("`pydicom` is not installed")
dicom = dcmread(path)
array = dicom.pixel_array.astype("float32")
m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
array = array * m + b
if windows is None:
return array
self.validate_windows_type(windows)
if isinstance(windows, tuple):
windows = [windows]
arr_list = []
for WL, WW in windows:
arr_list.append(self.window(array.copy(), WL, WW))
array = np.stack(arr_list, axis=-1)
if array.shape[-1] == 1:
array = np.squeeze(array, axis=-1)
return array
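    # Usage sketch ("slice.dcm" is a hypothetical path):
    #   img = model.load_image_from_dicom("slice.dcm")                     # float32 HU, (h, w)
    #   img = model.load_image_from_dicom("slice.dcm", windows=(50, 400))  # uint8, (h, w)
    #   img = model.load_image_from_dicom(
    #       "slice.dcm", windows=[(50, 400), (-600, 1500)]
    #   )                                                                  # uint8, (h, w, 2)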
@staticmethod
def is_valid_dicom(
ds,
fname: str = "",
sort_by_instance_number: bool = False,
exclude_invalid_dicoms: bool = False,
) -> bool:
attributes = [
"pixel_array",
"RescaleSlope",
"RescaleIntercept",
]
if sort_by_instance_number:
attributes.append("InstanceNumber")
else:
attributes.append("ImagePositionPatient")
attributes.append("ImageOrientationPatient")
attributes_present = [hasattr(ds, attr) for attr in attributes]
valid = all(attributes_present)
        if not valid and not exclude_invalid_dicoms:
            missing = [
                attr
                for attr, present in zip(attributes, attributes_present)
                if not present
            ]
            raise Exception(
                f"invalid DICOM file [{fname}]: missing attributes: {missing}"
            )
return valid
@staticmethod
def most_common_element(lst):
return max(set(lst), key=lst.count)
@staticmethod
def center_crop_or_pad_borders(image, size):
height, width = image.shape[:2]
new_height, new_width = size
if new_height < height:
# crop top and bottom
crop_top = (height - new_height) // 2
crop_bottom = height - new_height - crop_top
image = image[crop_top:-crop_bottom]
elif new_height > height:
# pad top and bottom
pad_top = (new_height - height) // 2
pad_bottom = new_height - height - pad_top
image = np.pad(
image,
((pad_top, pad_bottom), (0, 0)),
mode="constant",
constant_values=0,
)
if new_width < width:
# crop left and right
crop_left = (width - new_width) // 2
crop_right = width - new_width - crop_left
image = image[:, crop_left:-crop_right]
elif new_width > width:
# pad left and right
pad_left = (new_width - width) // 2
pad_right = new_width - width - pad_left
image = np.pad(
image,
((0, 0), (pad_left, pad_right)),
mode="constant",
constant_values=0,
)
return image
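    # e.g. center_crop_or_pad_borders(np.zeros((512, 512)), (256, 600)).shape == (256, 600)
    # (height 512 -> 256 via symmetric crop; width 512 -> 600 via symmetric zero-pad)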
def load_stack_from_dicom_folder(
self,
path: str,
windows: tuple[int, int] | list[tuple[int, int]] | None = None,
dicom_extension: str = ".dcm",
sort_by_instance_number: bool = False,
exclude_invalid_dicoms: bool = False,
fix_unequal_shapes: str = "crop_pad",
return_sorted_dicom_files: bool = False,
) -> np.ndarray | tuple[np.ndarray, list[str]]:
if not _PYDICOM_AVAILABLE:
raise Exception("`pydicom` is not installed")
dicom_files = glob.glob(os.path.join(path, f"*{dicom_extension}"))
if len(dicom_files) == 0:
raise Exception(
f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
)
dicoms = [dcmread(f) for f in dicom_files]
dicoms = [
(d, dicom_files[idx])
for idx, d in enumerate(dicoms)
if self.is_valid_dicom(
d, dicom_files[idx], sort_by_instance_number, exclude_invalid_dicoms
)
]
# handles exclude_invalid_dicoms=True and return_sorted_dicom_files=True
# by only including valid DICOM filenames
dicom_files = [_[1] for _ in dicoms]
dicoms = [_[0] for _ in dicoms]
slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
shapes = np.stack([s.shape for s in slices], axis=0)
if not np.all(shapes == shapes[0]):
unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
standard_shape = tuple(unique_shapes[np.argmax(counts)])
print(
f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
)
if fix_unequal_shapes == "crop_pad":
slices = [
self.center_crop_or_pad_borders(s, standard_shape)
if s.shape != standard_shape
else s
for s in slices
]
            elif fix_unequal_shapes == "resize":
                # cv2.resize expects dsize as (width, height)
                slices = [
                    cv2.resize(s, (standard_shape[1], standard_shape[0]))
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
slices = np.stack(slices, axis=0)
# find orientation
orientation = [self.determine_dicom_orientation(dcm) for dcm in dicoms]
# use most common
orientation = self.most_common_element(orientation)
        # sort slices along the acquisition axis, either by InstanceNumber or by
        # the ImagePositionPatient component selected by `orientation`
        if sort_by_instance_number:
positions = [float(d.InstanceNumber) for d in dicoms]
else:
positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
indices = np.argsort(positions)
slices = slices[indices]
# rescale
m, b = (
[float(d.RescaleSlope) for d in dicoms],
[float(d.RescaleIntercept) for d in dicoms],
)
m, b = self.most_common_element(m), self.most_common_element(b)
slices = slices * m + b
if windows is not None:
self.validate_windows_type(windows)
if isinstance(windows, tuple):
windows = [windows]
arr_list = []
for WL, WW in windows:
arr_list.append(self.window(slices.copy(), WL, WW))
slices = np.stack(arr_list, axis=-1)
if slices.shape[-1] == 1:
slices = np.squeeze(slices, axis=-1)
if return_sorted_dicom_files:
return slices, [dicom_files[idx] for idx in indices]
return slices
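    # Usage sketch ("study_dir/" is a hypothetical folder of axial CT slices):
    #   stack = model.load_stack_from_dicom_folder("study_dir/", windows=(50, 400))
    #   # stack: uint8, (num_slices, h, w), sorted along the acquisition axis
    #   stack, files = model.load_stack_from_dicom_folder(
    #       "study_dir/", windows=(50, 400), return_sorted_dicom_files=True
    #   )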
def preprocess(
self,
x: np.ndarray,
mode: str = "2d",
torchify: bool = True,
add_batch_dim: bool = False,
device: str | torch.device | None = None,
    ) -> np.ndarray | torch.Tensor:
if device is not None:
assert torchify, "`torchify` must be `True` if specifying `device`"
        mode = mode.lower()
        assert mode in {"2d", "3d"}, f"`mode` must be '2d' or '3d', got `{mode}`"
        if mode == "2d":
            # note: cv2.resize expects dsize as (width, height)
            x = cv2.resize(x, self.image_size)
            if x.ndim == 2:
                # add channel dim: (h, w) -> (h, w, 1)
                x = x[:, :, np.newaxis]
        else:
            # resize each slice independently
            x = np.stack([cv2.resize(s, self.image_size) for s in x], axis=0)
            if x.ndim == 3:
                # add channel dim: (n, h, w) -> (n, h, w, 1)
                x = x[:, :, :, np.newaxis]
if torchify:
if x.ndim == 3:
x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
elif x.ndim == 4:
x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
if add_batch_dim:
if torchify:
x = x.unsqueeze(0)
else:
x = x[np.newaxis]
if device is not None:
x = x.to(device)
return x
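    # Usage sketch (assumes `stack` is a windowed uint8 volume from
    # load_stack_from_dicom_folder):
    #   x = model.preprocess(stack, mode="3d", add_batch_dim=True, device="cuda")
    #   # x: float tensor, (1, num_slices, c, h, w)
    #   probas = model(x)  # -> (1, num_slices, num_classes)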
def crop_single_plane(
self,
x: np.ndarray,
device: str | torch.device,
organ: str | list[str],
threshold: float = 0.5,
buffer: float | int = 0,
speed_up: str | None = None,
    ) -> tuple[int, int]:
num_slices = x.shape[0]
if speed_up is not None:
assert speed_up in ["fast", "faster", "fastest"]
if speed_up == "fast":
# 75% of slices
reduce_num_slices = 3 * num_slices // 4
elif speed_up == "faster":
# 50% of slices
reduce_num_slices = num_slices // 2
elif speed_up == "fastest":
# 33% of slices
reduce_num_slices = num_slices // 3
indices = np.linspace(0, num_slices - 1, reduce_num_slices).astype(int)
x = x[indices]
x = self.preprocess(x, mode="3d")
x = torch.from_numpy(x)
x = rearrange(x, "n h w c -> n c h w").float().to(device)
x = rearrange(x, "n c h w -> 1 n c h w")
if x.size(2) > 1:
# if multi-channel, take mean
x = x.mean(2, keepdim=True)
organ_cls = self.forward(x)[0]
if speed_up is not None:
# organ_cls.shape = (num_slices, num_classes)
organ_cls = (
F.interpolate(
organ_cls.transpose(1, 0).unsqueeze(0),
size=(num_slices,),
mode="linear",
)
.squeeze(0)
.transpose(1, 0)
)
assert organ_cls.shape[0] == num_slices
slices = []
for each_organ in organ:
slices.append(
torch.where(organ_cls[:, self.label2index[each_organ]] >= threshold)[0]
)
slices = torch.cat(slices)
slice_min, slice_max = slices.min().item(), slices.max().item()
if buffer > 0:
if isinstance(buffer, float):
# % buffer
diff = slice_max - slice_min
buf = int(buffer * diff)
else:
# absolute slice buffer
buf = buffer
slice_min = max(0, slice_min - buf)
slice_max = min(num_slices - 1, slice_max + buf)
return slice_min, slice_max
@torch.no_grad()
def crop(
self,
x: np.ndarray,
organ: str | list[str],
crop_dims: int | list[int] = 0,
device: str | torch.device | None = None,
raw_hu: bool = False,
threshold: float = 0.5,
buffer: float | int = 0,
speed_up: str | None = None,
    ) -> np.ndarray:
if device is None:
device = "cuda" if torch.cuda.is_available() else "cpu"
assert isinstance(x, np.ndarray)
assert x.ndim in {
3,
4,
}, f"x should be a 3D or 4D array, but got {x.ndim} dimensions"
if raw_hu:
# if input is in Hounsfield units, apply soft tissue window
x = self.window(x, WL=50, WW=400)
x0 = x
if not isinstance(organ, list):
organ = [organ]
if not isinstance(crop_dims, list):
crop_dims = [crop_dims]
assert max(crop_dims) <= 2
assert min(crop_dims) >= 0
if isinstance(buffer, float):
# percentage of cropped axis dimension
assert buffer < 1
if 0 in crop_dims:
smin0, smax0 = self.crop_single_plane(
x0, device, organ, threshold, buffer, speed_up
)
else:
smin0, smax0 = 0, x0.shape[0]
if 1 in crop_dims:
# swap plane
x = x0.swapaxes(1, 0)
smin1, smax1 = self.crop_single_plane(
x, device, organ, threshold, buffer, speed_up
)
else:
smin1, smax1 = 0, x0.shape[1]
if 2 in crop_dims:
# swap plane
x = x0.swapaxes(2, 0)
smin2, smax2 = self.crop_single_plane(
x, device, organ, threshold, buffer, speed_up
)
else:
smin2, smax2 = 0, x0.shape[2]
return x0[smin0 : smax0 + 1, smin1 : smax1 + 1, smin2 : smax2 + 1]
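# End-to-end sketch (hypothetical repo id and organ name; `organ` must be a key
# of `label2index`):
#   model = TotalClassifierModel.from_pretrained("ianpan/total-classifier")
#   stack = model.load_stack_from_dicom_folder("study_dir/", windows=(50, 400))
#   cropped = model.crop(stack, organ="liver", crop_dims=[0, 1, 2], buffer=0.05)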