Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,607 Bytes
a51c6d2 d762d4a a51c6d2 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 |
import pycolmap
from models.SpaTrackV2.models.predictor import Predictor
import yaml
import easydict
import os
import numpy as np
import cv2
import torch
import torchvision.transforms as T
from PIL import Image
import io
import moviepy.editor as mp
from models.SpaTrackV2.utils.visualizer import Visualizer
import tqdm
from models.SpaTrackV2.models.utils import get_points_on_a_grid
import glob
from rich import print
import argparse
import decord
from huggingface_hub import hf_hub_download
# Locations of the model assets used by get_tracker_predictor:
#   ckpt_dir: a local checkpoint path, or -- when no such path exists on disk --
#             a HuggingFace repo ID from which the checkpoint is downloaded.
#   cfg_dir:  path to the YAML model configuration file.
config = dict(
    ckpt_dir="Yuxihenry/SpatialTracker_Files",
    cfg_dir="config/magic_infer_moge.yaml",
)
def get_tracker_predictor(output_dir: str, vo_points: int = 756):
    """
    Initialize and return the tracker predictor and visualizer.

    The YAML model config is read from ``config["cfg_dir"]``. The checkpoint is
    taken from ``config["ckpt_dir"]`` directly when that path exists on disk;
    otherwise ``config["ckpt_dir"]`` is treated as a HuggingFace repo ID and
    ``SpaTrack3_offline.pth`` is downloaded from it. The model is switched to
    eval mode and moved to CUDA.

    Args:
        output_dir: Directory to save visualization results (created if missing)
        vo_points: Number of points for visual odometry; stored in
            ``cfg.model.track_num`` before the model is built
    Returns:
        Tuple of (tracker_predictor, visualizer)
    """
    os.makedirs(output_dir, exist_ok=True)

    # NOTE(review): FullLoader is acceptable for this bundled config file; do not
    # point cfg_dir at untrusted YAML.
    with open(config["cfg_dir"], "r", encoding="utf-8") as f:
        cfg = easydict.EasyDict(yaml.load(f, Loader=yaml.FullLoader))
    cfg.out_dir = output_dir
    cfg.model.track_num = vo_points

    # Resolve the checkpoint location once, then build the model in one place.
    ckpt_source = config["ckpt_dir"]
    if os.path.exists(ckpt_source):
        checkpoint_path = ckpt_source  # local file
    else:
        print(f"Downloading model from HuggingFace: {config['ckpt_dir']}")
        checkpoint_path = hf_hub_download(
            repo_id=ckpt_source,
            repo_type="model",
            filename="SpaTrack3_offline.pth",
        )
    model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
    model.eval()
    model.to("cuda")

    viser = Visualizer(save_dir=cfg.out_dir, grayscale=True,
                       fps=10, pad_value=0, tracks_leave_trace=5)
    return model, viser
def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
    """
    Run tracking on a video sequence and save the results in tapip3d format.

    Reads ``<temp_dir>/<video_name>.mp4`` and, if it exists, the query mask
    ``<temp_dir>/<video_name>.png``. Writes ``<temp_dir>/results/result.npz``
    and a visualization through ``viser`` (filename ``"test"``).

    Args:
        model: Tracker predictor instance
        viser: Visualizer instance
        temp_dir: Directory containing temporary files
        video_name: Name of the video file (without extension)
        grid_size: Size of the query-point grid passed to get_points_on_a_grid
        vo_points: Not used by this function (get_tracker_predictor applies the
            point budget to the model config); kept for interface compatibility
        fps: Temporal stride -- every ``fps``-th frame of the video is kept.
            Despite the name this is not a frame rate.

    Raises:
        ValueError: if the mask image exists but cannot be read, or if it
            excludes every grid point.
    """
    # Setup paths
    video_path = os.path.join(temp_dir, f"{video_name}.mp4")
    mask_path = os.path.join(temp_dir, f"{video_name}.png")
    out_dir = os.path.join(temp_dir, "results")
    os.makedirs(out_dir, exist_ok=True)

    # Float (N, C, H, W) frames with the shorter side capped at 336 px.
    video_tensor = _load_video_frames(video_path, fps)
    frame_h, frame_w = video_tensor.shape[2:]
    query_xyt = _grid_queries(grid_size, frame_h, frame_w, mask_path)

    # No depth or camera priors are passed to the model.
    depth_tensor = None
    intrs = None
    extrs = None

    # Run model inference
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        (
            c2w_traj, intrs, point_map, conf_depth,
            track3d_pred, track2d_pred, vis_pred, conf_pred, video
        ) = model.forward(video_tensor, depth=depth_tensor,
                          intrs=intrs, extrs=extrs,
                          queries=query_xyt,
                          fps=1, full_point=False, iters_track=4,
                          query_no_BA=True, fixed_cam=False, stage=1,
                          support_frame=len(video_tensor) - 1, replace_ratio=0.2)

    # Downscale results so the longer side is at most 336 px, to limit I/O.
    max_size = 336
    h, w = video.shape[2:]
    scale = min(max_size / h, max_size / w)
    if scale < 1:
        new_h, new_w = int(h * scale), int(w * scale)
        resize = T.Resize((new_h, new_w))
        video = resize(video)
        video_tensor = resize(video_tensor)
        point_map = resize(point_map)
        # Pixel-space quantities must follow the image resize.
        track2d_pred[..., :2] = track2d_pred[..., :2] * scale
        intrs[:, :2, :] = intrs[:, :2, :] * scale
        if depth_tensor is not None:
            depth_tensor = resize(depth_tensor)
        conf_depth = resize(conf_depth)

    # Visualize tracks
    viser.visualize(video=video[None],
                    tracks=track2d_pred[None][..., :2],
                    visibility=vis_pred[None], filename="test")

    # Save in tapip3d format. Track points are transformed by the per-frame
    # rotation/translation of c2w_traj (presumably camera-to-world poses).
    world_coords = (
        torch.einsum("tij,tnj->tni", c2w_traj[:, :3, :3], track3d_pred[:, :, :3].cpu())
        + c2w_traj[:, :3, 3][:, None, :]
    )
    data_npz_load = {
        "coords": world_coords.numpy(),
        "extrinsics": torch.inverse(c2w_traj).cpu().numpy(),
        "intrinsics": intrs.cpu().numpy(),
        # channel 2 of the point map -- presumably z / depth; TODO confirm
        "depths": point_map[:, 2, ...].cpu().numpy(),
        "video": video_tensor.cpu().numpy() / 255,
        "visibs": vis_pred.cpu().numpy(),
        "confs": conf_pred.cpu().numpy(),
        "confs_depth": conf_depth.cpu().numpy(),
    }
    np.savez(os.path.join(out_dir, "result.npz"), **data_npz_load)
    print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")


def _load_video_frames(video_path, stride, short_side=336):
    """Decode every ``stride``-th frame of a video as a float (N, C, H, W) tensor.

    If the shorter side exceeds ``short_side``, frames are downscaled so the
    shorter side equals ``short_side``; smaller videos keep their resolution.
    """
    video_reader = decord.VideoReader(video_path)
    # Fetch only the frames that are kept instead of the whole clip. These are
    # the same frames as slicing [::stride] afterwards, without holding the
    # full video in memory.
    frame_ids = list(range(0, len(video_reader), stride))
    frames = torch.from_numpy(video_reader.get_batch(frame_ids).asnumpy())
    frames = frames.permute(0, 3, 1, 2)  # (N, H, W, C) -> (N, C, H, W)
    h, w = frames.shape[2:]
    scale = short_side / min(h, w)
    if scale < 1:
        frames = T.Resize((int(h * scale), int(w * scale)))(frames)
    return frames.float()


def _grid_queries(grid_size, frame_h, frame_w, mask_path):
    """Return an (N, 3) numpy array of [t, x, y] grid queries with t = 0.

    When ``mask_path`` exists, only grid points that land on non-zero mask
    pixels are kept.
    """
    grid_pts = get_points_on_a_grid(grid_size, (frame_h, frame_w), device="cpu")
    if os.path.exists(mask_path):
        mask = cv2.imread(mask_path)
        if mask is None:
            raise ValueError(f"Could not read mask image: {mask_path}")
        # Match the (possibly downscaled) frame size; cv2 expects (W, H).
        mask = cv2.resize(mask, (frame_w, frame_h)).sum(axis=-1) > 0
        pts = grid_pts[0].long()
        # Clamp so points sitting on the far border cannot index out of range.
        xs = pts[..., 0].clamp(0, frame_w - 1).numpy()
        ys = pts[..., 1].clamp(0, frame_h - 1).numpy()
        keep = torch.from_numpy(mask[ys, xs])
        grid_pts = grid_pts[:, keep]
        if grid_pts.shape[1] == 0:
            raise ValueError(f"Mask {mask_path} excludes every grid point; nothing to track")
    # Prepend a zero time index to every (x, y) grid point.
    return torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()