algohunt
initial_commit
c295391
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import torch.nn as nn
from .dpt_head import DPTHead
from .track_modules.base_track_predictor import BaseTrackerPredictor
class TrackHead(nn.Module):
    """
    Track head that pairs a DPTHead feature extractor with a BaseTrackerPredictor.

    Backbone tokens are first turned into feature maps by the DPT head; the
    tracker then consumes those maps together with the query points and is
    asked to run a configurable number of refinement iterations.
    """

    def __init__(
        self,
        dim_in: int,
        patch_size: int = 14,
        features: int = 128,
        iters: int = 4,
        predict_conf: bool = True,
        stride: int = 2,
        corr_levels: int = 7,
        corr_radius: int = 4,
        hidden_size: int = 384,
    ) -> None:
        """
        Initialize the TrackHead module.

        Args:
            dim_in (int): Input dimension of tokens from the backbone
                (forwarded to DPTHead).
            patch_size (int): Size of image patches used in the vision
                transformer (stored on the module and forwarded to DPTHead).
            features (int): Number of feature channels requested from the
                feature extractor; also used as the tracker's ``latent_dim``.
            iters (int): Default number of refinement iterations used by
                ``forward`` when the caller does not pass ``iters``.
            predict_conf (bool): Whether the tracker should predict confidence
                scores (forwarded to BaseTrackerPredictor).
            stride (int): Stride value for the tracker predictor.
            corr_levels (int): Number of correlation pyramid levels
                (forwarded to BaseTrackerPredictor).
            corr_radius (int): Radius for correlation computation
                (forwarded to BaseTrackerPredictor).
            hidden_size (int): Size of hidden layers in the tracker network
                (forwarded to BaseTrackerPredictor).
        """
        super().__init__()

        self.patch_size = patch_size

        # Feature extractor based on DPT architecture.
        # Processes tokens into feature maps for tracking.
        # NOTE(review): down_ratio is fixed at 2 here while `stride` is only
        # forwarded to the tracker -- confirm the two are meant to agree when
        # a caller passes a non-default stride.
        self.feature_extractor = DPTHead(
            dim_in=dim_in,
            patch_size=patch_size,
            features=features,
            feature_only=True,  # Only output features, no activation
            down_ratio=2,  # Reduces spatial dimensions by factor of 2
            pos_embed=False,
        )

        # Tracker module that predicts point trajectories.
        # Takes feature maps and predicts coordinates and visibility.
        self.tracker = BaseTrackerPredictor(
            latent_dim=features,  # Match the output_dim of feature extractor
            predict_conf=predict_conf,
            stride=stride,
            corr_levels=corr_levels,
            corr_radius=corr_radius,
            hidden_size=hidden_size,
        )

        self.iters = iters

    def forward(self, aggregated_tokens_list, images, patch_start_idx, query_points=None, iters=None):
        """
        Forward pass of the TrackHead.

        Args:
            aggregated_tokens_list (list): List of aggregated tokens from the backbone.
            images (torch.Tensor): Input images of shape (B, S, C, H, W) where:
                B = batch size, S = sequence length.
            patch_start_idx (int): Starting index for patch tokens.
            query_points (torch.Tensor, optional): Query points to track.
                Passed through to the tracker unchanged; how ``None`` is
                handled is decided by the tracker.
            iters (int, optional): Number of refinement iterations. If None, uses self.iters.

        Returns:
            tuple:
                - coord_preds: Predicted coordinates for tracked points.
                - vis_scores: Visibility scores for tracked points.
                - conf_scores: Confidence scores for tracked points, as
                  returned by the tracker.

        Raises:
            ValueError: If ``images`` is not 5-dimensional.
        """
        # Explicit rank check. This replaces an unpack of images.shape into
        # five otherwise-unused locals; the exception type is unchanged.
        if len(images.shape) != 5:
            raise ValueError(
                f"images must have shape (B, S, C, H, W); got {tuple(images.shape)}"
            )

        # Extract features from tokens.
        # Presumably (B, S, C, H//2, W//2) given down_ratio=2 -- the exact
        # layout is determined by DPTHead, which is not visible here.
        feature_maps = self.feature_extractor(aggregated_tokens_list, images, patch_start_idx)

        # Use default iterations if not specified
        if iters is None:
            iters = self.iters

        # Perform tracking using the extracted features
        coord_preds, vis_scores, conf_scores = self.tracker(
            query_points=query_points,
            fmaps=feature_maps,
            iters=iters,
        )

        return coord_preds, vis_scores, conf_scores