Whole Body Segmentation
This is a whole body segmentation model based on the SegResNet architecture, fine-tuned from CT-FM.
Running instructions
Whole Body Segmentation Inference
This notebook demonstrates how to:
- Load a pre-trained whole body segmentation model from HuggingFace Hub
- Set up preprocessing and postprocessing pipelines
- Perform sliding window inference on CT volumes
- Save the segmentation results
The model segments 118 different anatomical structures from CT scans.
Setup
Install requirements and import necessary packages
# Install lighter_zoo package
%pip install lighter_zoo -U -qq
Note: you may need to restart the kernel to use updated packages.
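As an optional sanity check (not part of the original notebook), you can confirm which versions of the core dependencies were installed:
# Optional: report the installed versions of the core dependencies.
import torch
import monai
print("torch:", torch.__version__)
print("monai:", monai.__version__)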
# Imports
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage
)
from monai.inferers import SlidingWindowInferer
Load Model
Download and initialize the pre-trained model from HuggingFace Hub
# Load pre-trained model
model = SegResNet.from_pretrained(
"project-lighter/whole_body_segmentation",
force_download=True
)
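As an optional step not in the original notebook, you can put the loaded model into evaluation mode and inspect its size; if you move it to a GPU, remember that the preprocessed input tensor must be moved to the same device before calling the inferer.
# Optional sketch: prepare the loaded model for inference and report its size.
model.eval()  # disable dropout / batch-norm updates during inference
n_params = sum(p.numel() for p in model.parameters())
print(f"Loaded SegResNet with {n_params:,} parameters")
# To use a GPU (assumption: CUDA is available), also move the input tensor later:
# device = torch.device("cuda")
# model = model.to(device)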
Configure Inference
Set up sliding window inference for processing large volumes
# Configure sliding window inference
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],  # Size of patches to process
    sw_batch_size=2,          # Number of windows to process in parallel
    overlap=0.625,            # Overlap between windows (reduces boundary artifacts)
    mode="gaussian"           # Gaussian weighting for overlap regions
)
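As a rough back-of-the-envelope check (an addition, not part of the original notebook), the effective stride between windows is roi_size × (1 − overlap), which lets you estimate how many patches a volume will produce; the volume shape below is purely illustrative and only approximates MONAI's actual tiling.
import math

# Illustrative only: estimate the number of sliding windows for a volume.
roi = [96, 160, 160]
overlap = 0.625
volume_shape = [310, 320, 320]  # hypothetical cropped CT shape (D, H, W)
strides = [max(1, int(r * (1 - overlap))) for r in roi]  # 36, 60, 60
windows_per_axis = [
    math.ceil(max(s - r, 0) / st) + 1
    for s, r, st in zip(volume_shape, roi, strides)
]
print(windows_per_axis, "->", math.prod(windows_per_axis), "windows")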
Setup Processing Pipelines
Define preprocessing and postprocessing transforms
# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                          # Ensure correct data type
    Orientation(axcodes="SPL"),            # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,  # Min HU value
        a_max=2048,   # Max HU value
        b_min=0,      # Target min
        b_max=1,      # Target max
        clip=True     # Clip values outside range
    ),
    CropForeground()  # Remove background to reduce computation
])
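To make the intensity rescaling concrete, here is a small illustrative check (an addition, not part of the original notebook) of what ScaleIntensityRange does to a few Hounsfield values:
# Illustrative check of the HU -> [0, 1] mapping used above.
for hu in (-1024, 0, 2048, 3000):
    scaled = (hu - (-1024)) / (2048 - (-1024))  # linear rescale
    scaled = min(max(scaled, 0.0), 1.0)         # clip=True behaviour
    print(f"{hu:>5} HU -> {scaled:.3f}")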
# Postprocessing pipeline
postprocess = Compose([
    Activations(softmax=True),                   # Apply softmax to get probabilities
    AsDiscrete(argmax=True, dtype=torch.int32),  # Convert to class labels
    KeepLargestConnectedComponent(),             # Remove small disconnected regions
    Invert(transform=preprocess),                # Restore original space
    # Save the result
    SaveImage(output_dir="./segmentations")
])
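By default SaveImage writes each result into its own sub-folder and appends a "_trans" postfix, which is why the file written later ends up at .../segmentations/0/0_trans.nii.gz. The sketch below is optional and the values are illustrative, showing how those arguments can be overridden.
# Optional sketch: customize how SaveImage names and places the output.
save = SaveImage(
    output_dir="./segmentations",
    output_postfix="seg",   # default is "trans", hence files like 0_trans.nii.gz
    output_ext=".nii.gz",
    separate_folder=False   # write files directly into output_dir
)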
Run Inference
Process an input CT scan and generate segmentation
# Input path
input_path = "/home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/scans/s0114.nii.gz"
# Preprocess input
input_tensor = preprocess(input_path)
# Run inference
with torch.no_grad():
    output = inferer(input_tensor.unsqueeze(dim=0), model)[0]
# Copy metadata from input
output.applied_operations = input_tensor.applied_operations
output.affine = input_tensor.affine
# Postprocess and save result
result = postprocess(output[0])
print("โ
Segmentation completed and saved")
2025-01-16 18:41:57,674 INFO image_writer.py:197 - writing: /home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/segmentations/0/0_trans.nii.gz
✅ Segmentation completed and saved
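As an optional follow-up (not in the original notebook), you can reload the saved mask and check how many of the 118 labels appear; the path below is illustrative and depends on the input filename and the SaveImage settings.
# Optional sketch: reload the saved segmentation and inspect its labels.
seg = LoadImage(ensure_channel_first=True, image_only=True)("./segmentations/0/0_trans.nii.gz")
labels = torch.unique(torch.as_tensor(seg)).tolist()
print(f"{len(labels)} distinct labels found (including background)")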