# Spaces: Running on Zero
import pycolmap | |
from models.SpaTrackV2.models.predictor import Predictor | |
import yaml | |
import easydict | |
import os | |
import numpy as np | |
import cv2 | |
import torch | |
import torchvision.transforms as T | |
from PIL import Image | |
import io | |
import moviepy.editor as mp | |
from models.SpaTrackV2.utils.visualizer import Visualizer | |
import tqdm | |
from models.SpaTrackV2.models.utils import get_points_on_a_grid | |
import glob | |
from rich import print | |
import argparse | |
import decord | |
from huggingface_hub import hf_hub_download | |
# Locations of the model artifacts used by get_tracker_predictor.
#   ckpt_dir: either a local checkpoint path or a Hugging Face repo ID
#             (a local path is used when it exists on disk).
#   cfg_dir:  YAML file holding the model configuration.
config = dict(
    ckpt_dir="Yuxihenry/SpatialTracker_Files",  # HuggingFace repo ID
    cfg_dir="config/magic_infer_moge.yaml",
)
def get_tracker_predictor(output_dir: str, vo_points: int = 756, device: str = "cuda"):
    """
    Initialize and return the tracker predictor and visualizer.

    The model config is read from ``config["cfg_dir"]``. The checkpoint is
    loaded from ``config["ckpt_dir"]`` when that path exists locally.
    Otherwise ``ckpt_dir`` is treated as a Hugging Face repo ID and
    ``SpaTrack3_offline.pth`` is downloaded from it.

    Args:
        output_dir: Directory to save visualization results. It is created
            if it does not already exist.
        vo_points: Number of points for visual odometry. The value is written
            to ``cfg.model.track_num`` before the model is built.
        device: Device the model is moved to. Defaults to ``"cuda"``, which
            matches the previous hard-coded behavior.

    Returns:
        Tuple of (tracker_predictor, visualizer).
    """
    os.makedirs(output_dir, exist_ok=True)

    with open(config["cfg_dir"], "r", encoding="utf-8") as f:
        # NOTE(review): FullLoader is kept for compatibility with the existing
        # config file. Only point cfg_dir at trusted YAML.
        cfg = easydict.EasyDict(yaml.load(f, Loader=yaml.FullLoader))
    cfg.out_dir = output_dir
    cfg.model.track_num = vo_points

    ckpt = config["ckpt_dir"]
    if os.path.exists(ckpt):
        # Local checkpoint: load it directly.
        checkpoint_path = ckpt
    else:
        # Not on disk: treat ckpt_dir as a HuggingFace repo and download
        # the offline checkpoint file from it.
        print(f"Downloading model from HuggingFace: {ckpt}")
        checkpoint_path = hf_hub_download(
            repo_id=ckpt,
            repo_type="model",
            filename="SpaTrack3_offline.pth",
        )

    model = Predictor.from_pretrained(checkpoint_path, model_cfg=cfg["model"])
    model.eval()
    model.to(device)

    viser = Visualizer(
        save_dir=cfg.out_dir,
        grayscale=True,
        fps=10,
        pad_value=0,
        tracks_leave_trace=5,
    )
    return model, viser
def run_tracker(model, viser, temp_dir, video_name, grid_size, vo_points, fps=3):
    """
    Run tracking on a video sequence and save the results in tapip3d format.

    Reads ``<temp_dir>/<video_name>.mp4``. If ``<temp_dir>/<video_name>.png``
    exists, it is used as a mask that restricts the query grid. The function
    runs the tracker and calls ``viser.visualize`` on the (possibly
    downscaled) video and 2D tracks. Results are written to
    ``<temp_dir>/results/result.npz``.

    Args:
        model: Tracker predictor instance.
        viser: Visualizer instance.
        temp_dir: Directory containing the input video and optional mask.
            Results are written to its ``results`` sub-directory.
        video_name: Name of the video file (without extension).
        grid_size: Size of the query-point grid passed to
            ``get_points_on_a_grid``.
        vo_points: Number of points for visual odometry.
            NOTE(review): not used inside this function. Presumably it is
            applied when the predictor is built; confirm.
        fps: Temporal subsampling stride: every ``fps``-th frame is kept.
            Despite its name, this is not a frame rate. Must be >= 1.

    Raises:
        ValueError: If ``fps`` < 1 or the mask image exists but cannot be read.
    """
    if fps < 1:
        raise ValueError(f"fps must be a positive frame stride, got {fps}")

    # Setup paths
    video_path = os.path.join(temp_dir, f"{video_name}.mp4")
    mask_path = os.path.join(temp_dir, f"{video_name}.png")
    out_dir = os.path.join(temp_dir, "results")
    os.makedirs(out_dir, exist_ok=True)

    # Decode only the frames kept by the ``fps`` stride instead of the whole
    # video, then reorder (N, H, W, C) -> (N, C, H, W).
    video_reader = decord.VideoReader(video_path)
    frame_ids = list(range(0, len(video_reader), fps))
    video_tensor = torch.from_numpy(
        video_reader.get_batch(frame_ids).asnumpy()
    ).permute(0, 3, 1, 2)

    # Downscale (never upscale) so that the shorter side becomes 336 px.
    h, w = video_tensor.shape[2:]
    scale = max(336 / h, 336 / w)
    if scale < 1:
        video_tensor = T.Resize((int(h * scale), int(w * scale)))(video_tensor)
    video_tensor = video_tensor.float()

    # Build the query grid over the (resized) frame.
    frame_H, frame_W = video_tensor.shape[2:]
    grid_pts = get_points_on_a_grid(grid_size, (frame_H, frame_W), device="cpu")

    # Optional mask: keep only grid points that land on non-black mask pixels.
    if os.path.exists(mask_path):
        mask = cv2.imread(mask_path)
        if mask is None:
            raise ValueError(f"Could not read mask image: {mask_path}")
        mask = cv2.resize(mask, (frame_W, frame_H)).sum(axis=-1) > 0
        grid_pts_int = grid_pts[0].long().numpy()
        # mask is indexed [row=y, col=x]; grid points are stored as (x, y).
        keep = torch.from_numpy(mask[grid_pts_int[..., 1], grid_pts_int[..., 0]])
        grid_pts = grid_pts[:, keep]

    # Prepend t=0 to every (x, y) point, so all queries start on frame 0.
    query_xyt = torch.cat([torch.zeros_like(grid_pts[:, :, :1]), grid_pts], dim=2)[0].numpy()

    # Run model inference. No depth, intrinsics or extrinsics are supplied;
    # the model returns its own estimates.
    with torch.amp.autocast(device_type="cuda", dtype=torch.bfloat16):
        (
            c2w_traj, intrs, point_map, conf_depth,
            track3d_pred, track2d_pred, vis_pred, conf_pred, video,
        ) = model.forward(
            video_tensor, depth=None,
            intrs=None, extrs=None,
            queries=query_xyt,
            fps=1, full_point=False, iters_track=4,
            query_no_BA=True, fixed_cam=False, stage=1,
            support_frame=len(video_tensor) - 1, replace_ratio=0.2,
        )

    # Downscale the results so the longer side is at most 336 px, which keeps
    # I/O small. Resize every per-pixel output (including conf_depth, so it
    # stays aligned with the saved depths). Rescale the 2D tracks and
    # intrinsics to the new resolution.
    max_size = 336
    h, w = video.shape[2:]
    scale = min(max_size / h, max_size / w)
    if scale < 1:
        resize = T.Resize((int(h * scale), int(w * scale)))
        video = resize(video)
        video_tensor = resize(video_tensor)
        point_map = resize(point_map)
        conf_depth = resize(conf_depth)
        track2d_pred[..., :2] = track2d_pred[..., :2] * scale
        intrs[:, :2, :] = intrs[:, :2, :] * scale

    # Visualize tracks
    viser.visualize(video=video[None],
                    tracks=track2d_pred[None][..., :2],
                    visibility=vis_pred[None], filename="test")

    # Apply the per-frame c2w rotation R and translation t to the 3D tracks
    # (p' = R @ p + t). Move everything to CPU first so the einsum and the
    # numpy conversion do not depend on where the model left the tensors.
    c2w_traj = c2w_traj.cpu()
    rot, trans = c2w_traj[:, :3, :3], c2w_traj[:, :3, 3]
    coords = torch.einsum("tij,tnj->tni", rot, track3d_pred[:, :, :3].cpu()) + trans[:, None, :]

    # Save in tapip3d format
    data_npz_load = {
        "coords": coords.numpy(),
        "extrinsics": torch.inverse(c2w_traj).numpy(),
        "intrinsics": intrs.cpu().numpy(),
        "depths": point_map[:, 2, ...].cpu().numpy(),
        "video": video_tensor.cpu().numpy() / 255,
        "visibs": vis_pred.cpu().numpy(),
        "confs": conf_pred.cpu().numpy(),
        "confs_depth": conf_depth.cpu().numpy(),
    }
    np.savez(os.path.join(out_dir, "result.npz"), **data_npz_load)
    print(f"Results saved to {out_dir}.\nTo visualize them with tapip3d, run: [bold yellow]python tapip3d_viz.py {out_dir}/result.npz[/bold yellow]")