Whole Body Segmentation

This model is a whole body segmentation model based on the SegResNet architecture. It was fine-tuned on CT-FM.

Running instructions

Whole Body Segmentation Inference

This notebook demonstrates how to:

  1. Load a pre-trained whole body segmentation model from HuggingFace Hub
  2. Set up preprocessing and postprocessing pipelines
  3. Perform sliding window inference on CT volumes
  4. Save the segmentation results

The model segments 118 different anatomical structures from CT scans.

Setup

Install requirements and import necessary packages

# Install lighter_zoo package
%pip install lighter_zoo -U -qq
Note: you may need to restart the kernel to use updated packages.
# Imports
import torch
from lighter_zoo import SegResNet
from monai.transforms import (
    Compose, LoadImage, EnsureType, Orientation,
    ScaleIntensityRange, CropForeground, Invert,
    Activations, AsDiscrete, KeepLargestConnectedComponent,
    SaveImage
)
from monai.inferers import SlidingWindowInferer

Load Model

Download and initialize the pre-trained model from HuggingFace Hub

# Load pre-trained model
model = SegResNet.from_pretrained(
    "project-lighter/whole_body_segmentation",
    force_download=True
)

Configure Inference

Set up sliding window inference for processing large volumes

# Configure sliding window inference
inferer = SlidingWindowInferer(
    roi_size=[96, 160, 160],  # Size of patches to process
    sw_batch_size=2,          # Number of windows to process in parallel
    overlap=0.625,            # Overlap between windows (reduces boundary artifacts)
    mode="gaussian"           # Gaussian weighting for overlap regions
)
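To get a feel for what these settings cost, the window grid can be estimated by hand. The sketch below is a simplified approximation of how a sliding window inferer lays out patches (step = ROI size × (1 − overlap), with the last window shifted back to fit inside the volume); it is not MONAI's actual code. The helper names `window_starts` and `count_windows` and the volume shape are assumptions made for illustration only.

```python
import math


def window_starts(image_len: int, roi_len: int, overlap: float) -> list[int]:
    """Start indices of sliding windows along a single axis."""
    if image_len <= roi_len:
        # The volume fits in one window along this axis.
        return [0]
    # Distance between consecutive window starts.
    step = max(int(roi_len * (1 - overlap)), 1)
    count = math.ceil((image_len - roi_len) / step) + 1
    # Shift the last window back so it ends at the volume boundary.
    return [min(i * step, image_len - roi_len) for i in range(count)]


def count_windows(image_shape, roi_size, overlap: float) -> int:
    """Total number of windows over all axes."""
    total = 1
    for image_len, roi_len in zip(image_shape, roi_size):
        total *= len(window_starts(image_len, roi_len, overlap))
    return total


# Hypothetical cropped volume shape, chosen only for illustration.
example_shape = (200, 320, 320)
print(window_starts(200, 96, 0.625))
print(count_windows(example_shape, (96, 160, 160), 0.625))
```

With `overlap=0.625` the step is 36 voxels along the first axis and 60 along the other two, so neighbouring windows share most of their content. Lowering the overlap reduces the number of windows, and therefore runtime, at the cost of less smoothing across patch boundaries.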

Setup Processing Pipelines

Define preprocessing and postprocessing transforms

# Preprocessing pipeline
preprocess = Compose([
    LoadImage(ensure_channel_first=True),  # Load image and ensure channel dimension
    EnsureType(),                         # Ensure correct data type
    Orientation(axcodes="SPL"),           # Standardize orientation
    # Scale intensity to [0,1] range, clipping outliers
    ScaleIntensityRange(
        a_min=-1024,    # Min HU value
        a_max=2048,     # Max HU value
        b_min=0,        # Target min
        b_max=1,        # Target max
        clip=True       # Clip values outside range
    ),
    CropForeground()    # Remove background to reduce computation
])

# Postprocessing pipeline
postprocess = Compose([
    Activations(softmax=True),              # Apply softmax to get probabilities
    AsDiscrete(argmax=True, dtype=torch.int32),  # Convert to class labels
    KeepLargestConnectedComponent(),        # Remove small disconnected regions
    Invert(transform=preprocess),           # Restore original space
    # Save the result
    SaveImage(output_dir="./segmentations")
])
/home/suraj/miniconda3/lib/python3.10/site-packages/monai/utils/deprecate_utils.py:321: FutureWarning: monai.transforms.croppad.array CropForeground.__init__:allow_smaller: Current default value of argument `allow_smaller=True` has been deprecated since version 1.2. It will be changed to `allow_smaller=False` in version 1.5.
  warn_deprecated(argname, msg, warning_category)
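Two of the transforms above are simple per-voxel arithmetic, and it can help to see them on scalar values. The following is a plain-Python sketch of what the intensity scaling and the softmax/argmax steps compute for a single voxel; it is not MONAI's implementation, and the function names `scale_intensity`, `softmax`, and `predicted_label` are hypothetical.

```python
import math


def scale_intensity(hu, a_min=-1024.0, a_max=2048.0, b_min=0.0, b_max=1.0):
    """Map a Hounsfield value from [a_min, a_max] to [b_min, b_max], clipping outliers."""
    # Clipping the input first gives the same result as clipping the output,
    # because the mapping is monotonically increasing.
    clipped = min(max(hu, a_min), a_max)
    return (clipped - a_min) / (a_max - a_min) * (b_max - b_min) + b_min


def softmax(logits):
    """Numerically stable softmax over the class scores of one voxel."""
    peak = max(logits)
    exps = [math.exp(value - peak) for value in logits]
    total = sum(exps)
    return [e / total for e in exps]


def predicted_label(logits):
    """Index of the most probable class for one voxel."""
    probs = softmax(logits)
    return max(range(len(probs)), key=probs.__getitem__)


print(scale_intensity(512))                # a value halfway through the HU window
print(predicted_label([0.1, 2.0, -1.0]))   # class with the highest score
```

Because softmax preserves the ordering of the scores, the argmax of the probabilities is the same as the argmax of the raw logits; the softmax step matters only if the probabilities themselves are needed.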

Run Inference

Process an input CT scan and generate segmentation

# Input path
input_path = "/home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/scans/s0114.nii.gz"

# Preprocess input
input_tensor = preprocess(input_path)

# Run inference
with torch.no_grad():
    output = inferer(input_tensor.unsqueeze(dim=0), model)[0]

# Copy metadata from input
output.applied_operations = input_tensor.applied_operations
output.affine = input_tensor.affine

# Postprocess and save result
result = postprocess(output[0])
print("✅ Segmentation completed and saved")
2025-01-16 18:41:57,674 INFO image_writer.py:197 - writing: /home/suraj/Repositories/lighter-ct-fm/semantic-search-app/assets/segmentations/0/0_trans.nii.gz
✅ Segmentation completed and saved
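The cell above processes a single hard-coded path. To run the same steps over several scans, one option is to collect the input paths first and loop over them. This is a minimal sketch, assuming the scans sit in one directory as `.nii.gz` files; the helper `find_scans` and the directory name `./scans` are hypothetical and not part of the original notebook.

```python
from pathlib import Path


def find_scans(scan_dir):
    """Return all .nii.gz files directly inside scan_dir, sorted by name."""
    return sorted(Path(scan_dir).glob("*.nii.gz"))


# Usage sketch (assumes preprocess, inferer, model and postprocess are
# defined as in the cells above; "./scans" is a placeholder directory):
#
# for scan_path in find_scans("./scans"):
#     input_tensor = preprocess(str(scan_path))
#     ...  # run inference and postprocessing as shown above
```

Sorting the paths keeps the processing order stable between runs, which makes it easier to match log lines to input files.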