import pycolmap
from models.SpaTrackV2.models.predictor import Predictor
import yaml
import easydict
import os
import numpy as np
import cv2
import torch
import torchvision.transforms as T
from PIL import Image
import io
import moviepy.editor as mp
from models.SpaTrackV2.utils.visualizer import Visualizer
import tqdm
from models.SpaTrackV2.models.utils import get_points_on_a_grid
import glob
from rich import print
import argparse
import decord
from huggingface_hub import hf_hub_download

config = {
    "ckpt_dir": "Yuxihenry/SpatialTracker_Files",  # HuggingFace repo ID
    "cfg_dir": "config/magic_infer_moge.yaml",
}

def get_tracker_predictor(output_dir: str, vo_points: int = 756):
    """
    Initialize and return the tracker predictor and visualizer
    Args:
        output_dir: Directory to save visualization results
        vo_points: Number of points for visual odometry
    Returns:
        Tuple of (tracker_predictor, visualizer)
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(config["cfg_dir"], "r") as f:
        cfg = yaml.load(f, Loader=yaml.FullLoader)
    cfg = easydict.EasyDict(cfg)
    cfg.out_dir = output_dir
    cfg.model.track_num = vo_points
    
    # Check if it's a local path or HuggingFace repo
    if os.path.exists(config["ckpt_dir"]):
        # Local file
        model = Predictor.from_pretrained(config["ckpt_dir"], model_cfg=cfg["model"])
    else:
        # HuggingFace repo - download the model
        print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
        checkpoint_path = hf_hub_download(
            repo_id=config["ckpt_dir"],
            repo_type="model",
            filename="SpaTrack3_offline.pth"
        )
        model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
    
    model.eval()
    model.to("cuda")
    viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
                       fps=10, pad_value=0, tracks_leave_trace=5)

    return model, viser

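# Illustrative usage (a sketch, not part of the original script; the output path is an assumption):
#   model, viser = get_tracker_predictor("./results", vo_points=756)
# This loads the checkpoint (downloading it from HuggingFace when "ckpt_dir" is not a
# local path) and returns the predictor together with a Visualizer writing to ./results.
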
def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
    """
    Run tracking on a video sequence
    Args:
        model: Tracker predictor instance
        viser: Visualizer instance
        temp_dir: Directory containing temporary files
        video_name: Name of the video file (without extension)
        grid_size: Size of the tracking grid
        vo_points: Number of points for visual odometry
        fps: Temporal subsampling stride (every fps-th frame of the video is kept)
    """
    # Setup paths
    video_path = os.path.join(temp_dir, f"{video_name}.mp4")
    mask_path = os.path.join(temp_dir, f"{video_name}.png")
    out_dir = os.path.join(temp_dir, "results")
    os.makedirs(out_dir, exist_ok=True)
    
    # Load video with decord and convert to a (N, C, H, W) uint8 tensor
    video_reader = decord.VideoReader(video_path)
    video_tensor = torch.from_numpy(
        video_reader.get_batch(range(len(video_reader))).asnumpy()
    ).permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    # Downscale so the shorter side becomes 336; smaller videos are left untouched
    h, w = video_tensor.shape[2:]
    scale = max(336 / h, 336 / w)
    if scale < 1:
        new_h, new_w = int(h * scale), int(w * scale)
        video_tensor = T.Resize((new_h, new_w))(video_tensor)
    video_tensor = video_tensor[::fps].float()  # temporal subsampling: keep every fps-th frame
    depth_tensor = None
    intrs = None
    extrs = None
    data_npz_load = {}
    
    # Load the optional foreground mask; fall back to an all-ones mask if none exists
    if os.path.exists(mask_path):
        mask = cv2.imread(mask_path)
        mask = cv2.resize(mask, (video_tensor.shape[3], video_tensor.shape[2]))
        mask = mask.sum(axis=-1) > 0
    else:
        mask = np.ones_like(video_tensor[0, 0].numpy()) > 0
    
    # Get frame dimensions and create grid points
    frame_H, frame_W = video_tensor.shape[2:]
    grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")
    
    # Sample mask values at grid points and filter out points where mask=0
    if os.path.exists(mask_path):
        grid_pts_int = grid_pts[0].long()
        mask_values = mask[grid_pts_int[...,1], grid_pts_int[...,0]]
        grid_pts = grid_pts[:, mask_values]
    
    # Queries are (t, x, y); t = 0 means every track is queried at the first frame
    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()
    
    # Run model inference
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        (
            c2w_traj, intrs, point_map, conf_depth,
            track3d_pred, track2d_pred, vis_pred, conf_pred, video
        ) = model.forward(video_tensor, depth=depth_tensor,
                            intrs=intrs, extrs=extrs, 
                            queries=query_xyt,
                            fps=1, full_point=False, iters_track=4,
                            query_no_BA=True, fixed_cam=False, stage=1,
                            support_frame=len(video_tensor)-1, replace_ratio=0.2) 
        
        # Downscale the outputs to limit the I/O burden
        max_size = 336
        h, w = video.shape[2:]
        scale = min(max_size / h, max_size / w)
        if scale < 1:
            new_h, new_w = int(h * scale), int(w * scale)
            video = T.Resize((new_h, new_w))(video)
            video_tensor = T.Resize((new_h, new_w))(video_tensor)
            point_map = T.Resize((new_h, new_w))(point_map)
            track2d_pred[...,:2] = track2d_pred[...,:2] * scale
            intrs[:,:2,:] = intrs[:,:2,:] * scale
            if depth_tensor is not None:
                depth_tensor = T.Resize((new_h, new_w))(depth_tensor)
            conf_depth = T.Resize((new_h, new_w))(conf_depth)
        
        # Visualize 2D tracks overlaid on the video
        viser.visualize(video=video[None],
                        tracks=track2d_pred[None][..., :2],
                        visibility=vis_pred[None], filename="test")
                        
        # Save in tapip3d format: lift camera-space 3D tracks into world coordinates via the c2w trajectory
        data_npz_load["coords"] = (
            torch.einsum("tij,tnj->tni", c2w_traj[:, :3, :3], track3d_pred[:, :, :3].cpu())
            + c2w_traj[:, :3, 3][:, None, :]
        ).numpy()
        data_npz_load["extrinsics"] = torch.inverse(c2w_traj).cpu().numpy()
        data_npz_load["intrinsics"] = intrs.cpu().numpy()
        data_npz_load["depths"] = point_map[:,2,...].cpu().numpy()
        data_npz_load["video"] = (video_tensor).cpu().numpy()/255
        data_npz_load["visibs"] = vis_pred.cpu().numpy()
        data_npz_load["confs"] = conf_pred.cpu().numpy()
        data_npz_load["confs_depth"] = conf_depth.cpu().numpy()
        np.savez(os.path.join(out_dir, 'result.npz'), **data_npz_load)
        
        print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")