import cv2
import glob
import numpy as np
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import rearrange
from transformers import PreTrainedModel
from timm import create_model

from .configuration import TotalClassifierConfig
from .label2index import label2index

_PYDICOM_AVAILABLE = False
try:
    from pydicom import dcmread

    _PYDICOM_AVAILABLE = True
except ModuleNotFoundError:
    pass

_PANDAS_AVAILABLE = False
try:
    import pandas as pd

    _PANDAS_AVAILABLE = True
except ModuleNotFoundError:
    pass


class RNNHead(nn.Module):
    def __init__(
        self,
        rnn_type: str,
        rnn_num_layers: int,
        rnn_dropout: float,
        feature_dim: int,
        linear_dropout: float,
        num_classes: int,
    ):
        super().__init__()
        self.rnn = getattr(nn, rnn_type)(
            input_size=feature_dim,
            hidden_size=feature_dim // 2,
            num_layers=rnn_num_layers,
            dropout=rnn_dropout,
            batch_first=True,
            bidirectional=True,
        )
        self.dropout = nn.Dropout(linear_dropout)
        self.linear = nn.Linear(feature_dim, num_classes)

    @staticmethod
    def convert_seq_and_mask_to_packed_sequence(
        seq: torch.Tensor, mask: torch.Tensor
    ) -> nn.utils.rnn.PackedSequence:
        assert seq.shape[0] == mask.shape[0]
        lengths = mask.sum(1)
        seq = nn.utils.rnn.pack_padded_sequence(
            seq, lengths.cpu().int(), batch_first=True, enforce_sorted=False
        )
        return seq

    def forward(
        self, x: torch.Tensor, mask: torch.Tensor | None = None
    ) -> torch.Tensor:
        skip = x
        if mask is not None:
            # convert to PackedSequence
            L = x.shape[1]
            x = self.convert_seq_and_mask_to_packed_sequence(x, mask)
        x, _ = self.rnn(x)
        if mask is not None:
            # convert back to tensor
            x = nn.utils.rnn.pad_packed_sequence(
                x, batch_first=True, total_length=L
            )[0]
        # residual connection: bidirectional output dim matches input dim
        x = x + skip
        return self.linear(self.dropout(x))


class TotalClassifierModel(PreTrainedModel):
    config_class = TotalClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.image_size = config.image_size
        self.backbone = create_model(
            model_name=config.backbone,
            pretrained=False,
            num_classes=0,
            global_pool="",
            features_only=True,
            in_chans=config.in_chans,
        )
        self.cnn_dropout = nn.Dropout(p=config.cnn_dropout)
        self.head = RNNHead(
            rnn_type=config.rnn_type,
            rnn_num_layers=config.rnn_num_layers,
            rnn_dropout=config.rnn_dropout,
            feature_dim=config.feature_dim,
            linear_dropout=config.linear_dropout,
            num_classes=config.num_classes,
        )
        self.label2index = label2index
        self.index2label = {v: k for k, v in self.label2index.items()}

    def forward(
        self,
        x: torch.Tensor,
        mask: torch.Tensor | None = None,
        return_logits: bool = False,
        return_as_dict: bool = False,
        return_as_list: bool = False,
        return_as_df: bool = False,
        threshold: float = 0.5,  # only used for return_as_list=True
    ) -> torch.Tensor | list:
        if return_as_df:
            assert (
                _PANDAS_AVAILABLE
            ), "`return_as_df=True` requires pandas to be installed"
        # x.shape = (b, n, c, h, w)
        b, n, c, h, w = x.shape
        # x = rearrange(x, "b n c h w -> (b n) c h w")
        x = x.reshape(b * n, c, h, w)
        x = self.normalize(x)
        features = self.backbone(x)
        # take last feature map, then average pooling
        features = F.adaptive_avg_pool2d(features[-1], 1).flatten(1)
        features = self.cnn_dropout(features)
        # features = rearrange(features, "(b n) d -> b n d", b=b, n=n)
        features = features.reshape(b, n, -1)
        logits = self.head(features, mask=mask)
        if return_logits:
            # return raw logits
            return logits
        probas = logits.sigmoid()
        if return_as_dict or return_as_df:
            # list of dictionaries
            batch_list = []
            for i in range(probas.shape[0]):
                dict_for_batch = {}
                probas_i = probas[i]
                for each_class in range(probas_i.shape[1]):
                    dict_for_batch[self.index2label[each_class]] = probas_i[
                        :, each_class
                    ]
                if return_as_df:
                    batch_list.append(
                        pd.DataFrame(
                            {k: v.cpu().numpy() for k, v in dict_for_batch.items()}
                        )
                    )
                else:
                    batch_list.append(dict_for_batch)
            return batch_list
        if return_as_list:
            # returns list of lists of lists of strings
            # innermost list - list of strings for each organ present based on threshold
            # inner list - list of above for each slice
            # outer list - list of above for each batch element (studies)
            batch_list = []
            # probas.shape = (batch_size, num_slices, num_classes)
            for i in range(probas.shape[0]):
                probas_i = probas[i]
                # probas_i.shape = (num_slices, num_classes)
                list_for_batch = []
                for each_slice in range(probas_i.shape[0]):
                    list_for_batch.append(
                        [
                            self.index2label[each_class]
                            for each_class in range(probas_i.shape[1])
                            if probas_i[each_slice, each_class] >= threshold
                        ]
                    )
                batch_list.append(list_for_batch)
            return batch_list
        return probas
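    # Output formats for `forward`, given a batch `x` of shape (b, n, c, h, w):
    #   probas = model(x)                       # (b, n, num_classes) sigmoid probabilities
    #   logits = model(x, return_logits=True)   # raw logits, same shape
    #   dicts  = model(x, return_as_dict=True)  # per-study {label: (n,) tensor}
    #   dfs    = model(x, return_as_df=True)    # per-study pandas DataFrame (requires pandas)
    #   lists  = model(x, return_as_list=True)  # per-study, per-slice lists of label strings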
    def normalize(self, x: torch.Tensor) -> torch.Tensor:
        # [0, 255] -> [-1, 1]
        mini, maxi = 0.0, 255.0
        x = (x - mini) / (maxi - mini)
        x = (x - 0.5) * 2.0
        return x

    @staticmethod
    def window(x: np.ndarray, WL: int, WW: int) -> np.ndarray:
        # apply CT windowing: clip to [WL - WW/2, WL + WW/2], rescale to [0, 255]
        lower, upper = WL - WW // 2, WL + WW // 2
        x = np.clip(x, lower, upper)
        x = (x - lower) / (upper - lower)
        return (x * 255.0).astype("uint8")

    @staticmethod
    def validate_windows_type(windows):
        assert isinstance(windows, (tuple, list))
        if isinstance(windows, tuple):
            assert len(windows) == 2
            assert all(isinstance(_, int) for _ in windows)
        elif isinstance(windows, list):
            assert all(isinstance(_, tuple) for _ in windows)
            assert all(len(_) == 2 for _ in windows)
            assert all(isinstance(__, int) for _ in windows for __ in _)

    @staticmethod
    def determine_dicom_orientation(ds) -> int:
        iop = ds.ImageOrientationPatient
        # Calculate the direction cosine for the normal vector of the plane
        normal_vector = np.cross(iop[:3], iop[3:])
        # Determine the plane based on the largest component of the normal vector
        abs_normal = np.abs(normal_vector)
        if abs_normal[0] > abs_normal[1] and abs_normal[0] > abs_normal[2]:
            return 0  # sagittal
        elif abs_normal[1] > abs_normal[0] and abs_normal[1] > abs_normal[2]:
            return 1  # coronal
        else:
            return 2  # axial

    def load_image_from_dicom(
        self, path: str, windows: tuple[int, int] | list[tuple[int, int]] | None = None
    ) -> np.ndarray:
        # windows can be a tuple of (WINDOW_LEVEL, WINDOW_WIDTH),
        # or a list of such tuples if wishing to generate a multi-channel
        # image using > 1 window
        if not _PYDICOM_AVAILABLE:
            raise Exception("`pydicom` is not installed")
        dicom = dcmread(path)
        array = dicom.pixel_array.astype("float32")
        # rescale to Hounsfield units
        m, b = float(dicom.RescaleSlope), float(dicom.RescaleIntercept)
        array = array * m + b
        if windows is None:
            return array
        self.validate_windows_type(windows)
        if isinstance(windows, tuple):
            windows = [windows]
        arr_list = []
        for WL, WW in windows:
            arr_list.append(self.window(array.copy(), WL, WW))
        array = np.stack(arr_list, axis=-1)
        if array.shape[-1] == 1:
            array = np.squeeze(array, axis=-1)
        return array
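    # Usage sketch (hypothetical path; requires pydicom): a two-channel image
    # from a soft tissue (WL=50, WW=400) and a lung (WL=-600, WW=1500) window,
    # the latter given here only as an example window setting:
    #   img = model.load_image_from_dicom("slice.dcm", windows=[(50, 400), (-600, 1500)])
    #   img.shape -> (H, W, 2), dtype uint8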
    @staticmethod
    def is_valid_dicom(
        ds,
        fname: str = "",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
    ) -> bool:
        attributes = [
            "pixel_array",
            "RescaleSlope",
            "RescaleIntercept",
        ]
        if sort_by_instance_number:
            attributes.append("InstanceNumber")
        else:
            attributes.append("ImagePositionPatient")
            attributes.append("ImageOrientationPatient")
        attributes_present = [hasattr(ds, attr) for attr in attributes]
        valid = all(attributes_present)
        if not valid and not exclude_invalid_dicoms:
            raise Exception(
                f"invalid DICOM file [{fname}]: missing attributes: "
                f"{list(np.array(attributes)[~np.array(attributes_present)])}"
            )
        return valid

    @staticmethod
    def most_common_element(lst):
        return max(set(lst), key=lst.count)

    @staticmethod
    def center_crop_or_pad_borders(image, size):
        height, width = image.shape[:2]
        new_height, new_width = size
        if new_height < height:
            # crop top and bottom
            crop_top = (height - new_height) // 2
            crop_bottom = height - new_height - crop_top
            image = image[crop_top:-crop_bottom]
        elif new_height > height:
            # pad top and bottom
            pad_top = (new_height - height) // 2
            pad_bottom = new_height - height - pad_top
            image = np.pad(
                image,
                ((pad_top, pad_bottom), (0, 0)),
                mode="constant",
                constant_values=0,
            )
        if new_width < width:
            # crop left and right
            crop_left = (width - new_width) // 2
            crop_right = width - new_width - crop_left
            image = image[:, crop_left:-crop_right]
        elif new_width > width:
            # pad left and right
            pad_left = (new_width - width) // 2
            pad_right = new_width - width - pad_left
            image = np.pad(
                image,
                ((0, 0), (pad_left, pad_right)),
                mode="constant",
                constant_values=0,
            )
        return image

    def load_stack_from_dicom_folder(
        self,
        path: str,
        windows: tuple[int, int] | list[tuple[int, int]] | None = None,
        dicom_extension: str = ".dcm",
        sort_by_instance_number: bool = False,
        exclude_invalid_dicoms: bool = False,
        fix_unequal_shapes: str = "crop_pad",
        return_sorted_dicom_files: bool = False,
    ) -> np.ndarray | tuple[np.ndarray, list[str]]:
        if not _PYDICOM_AVAILABLE:
            raise Exception("`pydicom` is not installed")
        dicom_files = glob.glob(os.path.join(path, f"*{dicom_extension}"))
        if len(dicom_files) == 0:
            raise Exception(
                f"No DICOM files found in `{path}` using `dicom_extension={dicom_extension}`"
            )
        dicoms = [dcmread(f) for f in dicom_files]
        dicoms = [
            (d, dicom_files[idx])
            for idx, d in enumerate(dicoms)
            if self.is_valid_dicom(
                d, dicom_files[idx], sort_by_instance_number, exclude_invalid_dicoms
            )
        ]
        # handles exclude_invalid_dicoms=True and return_sorted_dicom_files=True
        # by only including valid DICOM filenames
        dicom_files = [_[1] for _ in dicoms]
        dicoms = [_[0] for _ in dicoms]
        slices = [dcm.pixel_array.astype("float32") for dcm in dicoms]
        shapes = np.stack([s.shape for s in slices], axis=0)
        if not np.all(shapes == shapes[0]):
            unique_shapes, counts = np.unique(shapes, axis=0, return_counts=True)
            standard_shape = tuple(unique_shapes[np.argmax(counts)])
            print(
                f"warning: different array shapes present, using {fix_unequal_shapes} -> {standard_shape}"
            )
            if fix_unequal_shapes == "crop_pad":
                slices = [
                    self.center_crop_or_pad_borders(s, standard_shape)
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
            elif fix_unequal_shapes == "resize":
                # cv2.resize expects dsize as (width, height)
                slices = [
                    cv2.resize(s, standard_shape[::-1])
                    if s.shape != standard_shape
                    else s
                    for s in slices
                ]
        slices = np.stack(slices, axis=0)
        # find orientation of each slice
        orientation = [self.determine_dicom_orientation(dcm) for dcm in dicoms]
        # use most common
        orientation = self.most_common_element(orientation)
        # sort using InstanceNumber or ImagePositionPatient;
        # orientation is the index to use for sorting
        if sort_by_instance_number:
            positions = [float(d.InstanceNumber) for d in dicoms]
        else:
            positions = [float(d.ImagePositionPatient[orientation]) for d in dicoms]
        indices = np.argsort(positions)
        slices = slices[indices]
        # rescale
        m, b = (
            [float(d.RescaleSlope) for d in dicoms],
            [float(d.RescaleIntercept) for d in dicoms],
        )
        m, b = self.most_common_element(m), self.most_common_element(b)
        slices = slices * m + b
        if windows is not None:
            self.validate_windows_type(windows)
            if isinstance(windows, tuple):
                windows = [windows]
            arr_list = []
            for WL, WW in windows:
                arr_list.append(self.window(slices.copy(), WL, WW))
            slices = np.stack(arr_list, axis=-1)
            if slices.shape[-1] == 1:
                slices = np.squeeze(slices, axis=-1)
        if return_sorted_dicom_files:
            return slices, [dicom_files[idx] for idx in indices]
        return slices
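    # Usage sketch (hypothetical folder): load a sorted, windowed stack and
    # keep the sorted filenames:
    #   stack, files = model.load_stack_from_dicom_folder(
    #       "study/", windows=(50, 400), return_sorted_dicom_files=True
    #   )
    #   stack.shape -> (num_slices, H, W), dtype uint8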
    def preprocess(
        self,
        x: np.ndarray,
        mode: str = "2d",
        torchify: bool = True,
        add_batch_dim: bool = False,
        device: str | torch.device | None = None,
    ) -> np.ndarray | torch.Tensor:
        if device is not None:
            assert torchify, "`torchify` must be `True` if specifying `device`"
        mode = mode.lower()
        if mode == "2d":
            x = cv2.resize(x, self.image_size)
            if x.ndim == 2:
                x = x[:, :, np.newaxis]
        elif mode == "3d":
            x = np.stack([cv2.resize(s, self.image_size) for s in x], axis=0)
            if x.ndim == 3:
                x = x[:, :, :, np.newaxis]
        if torchify:
            if x.ndim == 3:
                x = rearrange(torch.from_numpy(x).float(), "h w c -> c h w")
            elif x.ndim == 4:
                x = rearrange(torch.from_numpy(x).float(), "n h w c -> n c h w")
        if add_batch_dim:
            if torchify:
                x = x.unsqueeze(0)
            else:
                x = x[np.newaxis]
        if device is not None:
            x = x.to(device)
        return x

    def crop_single_plane(
        self,
        x: np.ndarray,
        device: str | torch.device,
        organ: str | list[str],
        threshold: float = 0.5,
        buffer: float | int = 0,
        speed_up: str | None = None,
    ) -> tuple[int, int]:
        num_slices = x.shape[0]
        if speed_up is not None:
            assert speed_up in ["fast", "faster", "fastest"]
            if speed_up == "fast":
                # 75% of slices
                reduce_num_slices = 3 * num_slices // 4
            elif speed_up == "faster":
                # 50% of slices
                reduce_num_slices = num_slices // 2
            elif speed_up == "fastest":
                # 33% of slices
                reduce_num_slices = num_slices // 3
            indices = np.linspace(0, num_slices - 1, reduce_num_slices).astype(int)
            x = x[indices]
        # keep NumPy output here; tensor conversion is done manually below
        x = self.preprocess(x, mode="3d", torchify=False)
        x = torch.from_numpy(x)
        x = rearrange(x, "n h w c -> n c h w").float().to(device)
        x = rearrange(x, "n c h w -> 1 n c h w")
        if x.size(2) > 1:
            # if multi-channel, take mean
            x = x.mean(2, keepdim=True)
        organ_cls = self.forward(x)[0]
        if speed_up is not None:
            # organ_cls.shape = (num_slices, num_classes)
            # interpolate predictions back to the original number of slices
            organ_cls = (
                F.interpolate(
                    organ_cls.transpose(1, 0).unsqueeze(0),
                    size=(num_slices,),
                    mode="linear",
                )
                .squeeze(0)
                .transpose(1, 0)
            )
        assert organ_cls.shape[0] == num_slices
        slices = []
        for each_organ in organ:
            slices.append(
                torch.where(organ_cls[:, self.label2index[each_organ]] >= threshold)[0]
            )
        slices = torch.cat(slices)
        slice_min, slice_max = slices.min().item(), slices.max().item()
        if buffer > 0:
            if isinstance(buffer, float):
                # % buffer
                diff = slice_max - slice_min
                buf = int(buffer * diff)
            else:
                # absolute slice buffer
                buf = buffer
            slice_min = max(0, slice_min - buf)
            slice_max = min(num_slices - 1, slice_max + buf)
        return slice_min, slice_max
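    # `crop_single_plane` runs the classifier along one axis and returns the
    # (min, max) slice indices containing the requested organ(s); `crop` below
    # applies it to up to three axes by swapping the target axis to the front.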
    @torch.no_grad()
    def crop(
        self,
        x: np.ndarray,
        organ: str | list[str],
        crop_dims: int | list[int] = 0,
        device: str | torch.device | None = None,
        raw_hu: bool = False,
        threshold: float = 0.5,
        buffer: float | int = 0,
        speed_up: str | None = None,
    ) -> np.ndarray:
        if device is None:
            device = "cuda" if torch.cuda.is_available() else "cpu"
        assert isinstance(x, np.ndarray)
        assert x.ndim in {
            3,
            4,
        }, f"x should be a 3D or 4D array, but got {x.ndim} dimensions"
        if raw_hu:
            # if input is in Hounsfield units, apply soft tissue window
            x = self.window(x, WL=50, WW=400)
        x0 = x
        if not isinstance(organ, list):
            organ = [organ]
        if not isinstance(crop_dims, list):
            crop_dims = [crop_dims]
        assert max(crop_dims) <= 2
        assert min(crop_dims) >= 0
        if isinstance(buffer, float):
            # percentage of cropped axis dimension
            assert buffer < 1
        if 0 in crop_dims:
            smin0, smax0 = self.crop_single_plane(
                x0, device, organ, threshold, buffer, speed_up
            )
        else:
            smin0, smax0 = 0, x0.shape[0]
        if 1 in crop_dims:
            # swap plane
            x = x0.swapaxes(1, 0)
            smin1, smax1 = self.crop_single_plane(
                x, device, organ, threshold, buffer, speed_up
            )
        else:
            smin1, smax1 = 0, x0.shape[1]
        if 2 in crop_dims:
            # swap plane
            x = x0.swapaxes(2, 0)
            smin2, smax2 = self.crop_single_plane(
                x, device, organ, threshold, buffer, speed_up
            )
        else:
            smin2, smax2 = 0, x0.shape[2]
        return x0[smin0 : smax0 + 1, smin1 : smax1 + 1, smin2 : smax2 + 1]
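

if __name__ == "__main__":
    # Minimal end-to-end sketch, not part of the library API. Assumes a
    # hypothetical local checkpoint directory "./checkpoint" containing a
    # config.json and weights compatible with TotalClassifierConfig
    # (with in_chans=1), and uses a random array in place of a real CT.
    model = TotalClassifierModel.from_pretrained("./checkpoint").eval()
    volume = np.random.randint(-1000, 1000, (32, 512, 512)).astype("float32")  # fake HU
    windowed = model.window(volume, WL=50, WW=400)  # soft tissue window -> uint8
    x = model.preprocess(windowed, mode="3d", add_batch_dim=True)  # (1, 32, 1, h, w)
    with torch.no_grad():
        probas = model(x)  # (1, 32, num_classes)
    print(probas.shape)
    # crop the volume to one organ along all three axes;
    # assumes "liver" is a key in label2index (hypothetical label)
    cropped = model.crop(volume, organ="liver", crop_dims=[0, 1, 2], raw_hu=True)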