xiaoyuxi
gradio_app
a51c6d2
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import os
import numpy as np
import cv2
import torch
import flow_vis
from matplotlib import cm
import torch.nn.functional as F
import torchvision.transforms as transforms
import moviepy
from moviepy.editor import ImageSequenceClip
import matplotlib.pyplot as plt
def read_video_from_path(path):
    """Read every frame of a video file into one RGB array.

    Args:
        path: Video source passed to ``cv2.VideoCapture`` (typically a file path).

    Returns:
        ``np.ndarray`` of shape ``(T, H, W, 3)`` with the frames in RGB channel
        order, or ``None`` (after printing an error) when the video cannot be
        opened or yields no decodable frames.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            print("Error opening video file")
            return None
        frames = []
        while True:
            ok, frame = cap.read()
            if not ok:
                break
            # OpenCV decodes to BGR; convert so callers receive RGB frames.
            frames.append(cv2.cvtColor(frame, cv2.COLOR_BGR2RGB))
        if not frames:
            # np.stack([]) would raise; report it like an open failure instead.
            print("Error: no frames could be read from video file")
            return None
        return np.stack(frames)
    finally:
        # Release the decoder on every path, including when decoding raises.
        cap.release()
class Visualizer:
    """Render predicted (and optionally ground-truth) point tracks onto a video.

    Args:
        save_dir: Directory that ``save_video`` writes ``.mp4`` files into.
        grayscale: If True, ``visualize`` converts the video to grayscale
            (repeated to 3 channels) before drawing.
        pad_value: Width in pixels of the constant border (filled with 255)
            that ``visualize`` adds on every side of each frame; predicted
            tracks are shifted by the same amount.
        fps: Frame rate used when saving or logging videos.
        mode: Coloring scheme. ``"rainbow"`` colors tracks by y coordinate
            with ``gist_rainbow``; ``"optical_flow"`` colors by displacement
            from the query frame via ``flow_vis``; any other value is looked
            up as a Matplotlib colormap name (e.g. ``"cool"``) and colors
            tracks by time, or by segmentation class when a mask is given.
        linewidth: Line width of trajectory segments; points are drawn with
            radius ``2 * linewidth``.
        show_first_frame: When > 0, the first rendered frame is shown this
            many times at the start of the output.
        tracks_leave_trace: Number of past frames whose trajectory segments
            are drawn on each frame; 0 draws none, a negative value (e.g. -1)
            draws the full history.
    """

    def __init__(
        self,
        save_dir: str = "./results",
        grayscale: bool = False,
        pad_value: int = 0,
        fps: int = 10,
        mode: str = "rainbow",  # 'cool', 'optical_flow', or any Matplotlib colormap name
        linewidth: int = 2,
        show_first_frame: int = 10,
        tracks_leave_trace: int = 0,  # -1 for infinite
    ):
        self.mode = mode
        self.save_dir = save_dir
        if mode == "rainbow":
            self.color_map = plt.get_cmap("gist_rainbow")
        elif mode == "optical_flow":
            # Colors come from flow_vis; no colormap is used in this mode.
            self.color_map = None
        else:
            # plt.get_cmap replaces matplotlib.cm.get_cmap, which was removed
            # in Matplotlib 3.9.
            self.color_map = plt.get_cmap(mode)
        self.show_first_frame = show_first_frame
        self.grayscale = grayscale
        self.tracks_leave_trace = tracks_leave_trace
        self.pad_value = pad_value
        self.linewidth = linewidth
        self.fps = fps

    def visualize(
        self,
        video: torch.Tensor,  # (B,T,C,H,W)
        tracks: torch.Tensor,  # (B,T,N,2)
        visibility: torch.Tensor = None,  # (B, T, N, 1) bool
        gt_tracks: torch.Tensor = None,  # (B,T,N,2)
        segm_mask: torch.Tensor = None,  # NOTE(review): documented as (B,1,H,W) but indexed [0, query_frame] below — confirm
        filename: str = "video",
        writer=None,  # tensorboard Summary Writer, used for visualization during training
        step: int = 0,
        query_frame: int = 0,
        save_video: bool = True,
        compensate_for_camera_motion: bool = False,
        rigid_part=None,
        video_depth=None,  # (B,T,C,H,W)
    ):
        """Draw tracks on ``video`` and optionally save or log the result.

        Args:
            video: Input frames, (B, T, C, H, W); values are cast with
                ``.byte()`` for drawing, so presumably in the 0-255 range.
            tracks: Predicted (x, y) pixel coordinates, (B, T, N, 2).
            visibility: Optional per-point visibility; points with value
                > 0.5 are drawn filled, the others as outlines.
            gt_tracks: Optional ground-truth tracks, (B, T, N, 2), drawn as red
                crosses when ``tracks_leave_trace != 0``. NOTE(review): they
                are not shifted by ``pad_value`` — confirm for padded videos.
            segm_mask: Optional mask sampled at each track's query-frame
                position to obtain one label per track.
            filename: Base name for saved videos.
            writer: Optional TensorBoard writer; when given, videos are logged
                through ``add_video`` instead of written to disk.
            step: Global step passed to ``writer``.
            query_frame: Frame whose track positions drive coloring and mask
                sampling.
            save_video: Whether to save/log the rendered video(s).
            compensate_for_camera_motion: Subtract the mean motion of
                background tracks (label <= 0) from trajectories and draw only
                foreground tracks. Requires ``segm_mask``.
            rigid_part: Optional tensor of per-track cluster labels; when
                given, points are colored per cluster with the ``jet`` map.
            video_depth: Optional depth video, (B, T, C, H, W); its first
                channel is colorized and saved as ``<filename>_depth``.

        Returns:
            uint8 tensor (1, T', 3, H', W') of rendered frames, where H', W'
            include padding and T' reflects ``show_first_frame``.
        """
        if compensate_for_camera_motion:
            assert segm_mask is not None
        if segm_mask is not None:
            # One label per track: sample the mask at each track's (x, y) in
            # the query frame, before padding shifts the coordinates.
            coords = tracks[0, query_frame].round().long()
            segm_mask = segm_mask[0, query_frame][coords[:, 1], coords[:, 0]].long()

        # pad_value is the border width; the border itself is filled with 255.
        video = F.pad(
            video,
            (self.pad_value, self.pad_value, self.pad_value, self.pad_value),
            "constant",
            255,
        )
        if video_depth is not None:
            video_depth = self._colorize_depth(video_depth)
        tracks = tracks + self.pad_value

        if self.grayscale:
            transform = transforms.Grayscale()
            video = transform(video)
            video = video.repeat(1, 1, 3, 1, 1)

        res_video = self.draw_tracks_on_video(
            video=video,
            tracks=tracks,
            visibility=visibility,
            segm_mask=segm_mask,
            gt_tracks=gt_tracks,
            query_frame=query_frame,
            compensate_for_camera_motion=compensate_for_camera_motion,
            rigid_part=rigid_part,
        )
        if save_video:
            self.save_video(res_video, filename=filename, writer=writer, step=step)
            if video_depth is not None:
                self.save_video(
                    video_depth, filename=filename + "_depth", writer=writer, step=step
                )
        return res_video

    @staticmethod
    def _colorize_depth(video_depth):
        """Colorize channel 0 of a (B, T, C, H, W) depth video with INFERNO.

        Returns a uint8 tensor of shape (1, T, 3, H, W) in RGB order.
        """
        # Assumes depth is normalized to [0, 1] — TODO confirm. Clamping makes
        # out-of-range values saturate instead of wrapping in the uint8 cast.
        depth = (
            (video_depth.detach() * 255).clamp(0, 255).cpu().numpy().astype(np.uint8)
        )
        frames = [
            # applyColorMap produces BGR; convert to the RGB order that
            # save_video / ImageSequenceClip expect.
            cv2.cvtColor(
                cv2.applyColorMap(depth[0, i, 0], cv2.COLORMAP_INFERNO),
                cv2.COLOR_BGR2RGB,
            )
            for i in range(depth.shape[1])
        ]
        return torch.from_numpy(np.stack(frames, axis=0)).permute(0, 3, 1, 2)[None]

    def save_video(self, video, filename, writer=None, step=0):
        """Log ``video`` to ``writer`` or write it to ``save_dir`` as an MP4.

        Args:
            video: Tensor indexed as (B, T, C, H, W); only batch element 0 is
                written to disk.
            filename: Base name; the output is ``<filename>_pred_track``.
            writer: Optional TensorBoard writer used instead of the disk path.
            step: Global step forwarded to ``writer.add_video``.
        """
        if writer is not None:
            writer.add_video(
                f"{filename}_pred_track",
                video.to(torch.uint8),
                global_step=step,
                fps=self.fps,
            )
            return
        os.makedirs(self.save_dir, exist_ok=True)
        frames = [frame[0].permute(1, 2, 0).cpu().numpy() for frame in video.unbind(1)]
        # NOTE(review): the first two frames and the last frame are dropped
        # before encoding; the reason is not visible here — confirm intent.
        clip = ImageSequenceClip(frames[2:-1], fps=self.fps)
        save_path = os.path.join(self.save_dir, f"{filename}_pred_track.mp4")
        clip.write_videofile(save_path, codec="libx264", fps=self.fps, logger=None)
        print(f"Video saved to {save_path}")

    @staticmethod
    def _to_numpy(x):
        """Return ``x`` as a NumPy array, detaching and moving tensors to CPU."""
        if torch.is_tensor(x):
            return x.detach().cpu().numpy()
        return np.asarray(x)

    def _compute_vector_colors(self, tracks, segm_mask, query_frame):
        """Return a (T, N, 3) array of per-frame, per-track colors in 0-255.

        Args:
            tracks: (T, N, 2) integer pixel coordinates.
            segm_mask: Optional (N,) array of per-track labels (> 0 foreground).
            query_frame: Frame used as reference for rainbow / optical-flow.
        """
        T, N, _ = tracks.shape
        if self.mode == "optical_flow":
            return flow_vis.flow_to_color(tracks - tracks[query_frame][None])

        vector_colors = np.zeros((T, N, 3))
        if segm_mask is None:
            if self.mode == "rainbow":
                # Color each track by its y coordinate in the query frame.
                if N > 0:
                    ys = tracks[query_frame, :, 1]
                    norm = plt.Normalize(ys.min(), ys.max())
                    for n in range(N):
                        vector_colors[:, n] = (
                            np.array(self.color_map(norm(ys[n]))[:3]) * 255
                        )
            else:
                # Color changes with time, identically for every track.
                for t in range(T):
                    vector_colors[t] = np.array(self.color_map(t / T)[:3]) * 255
            return vector_colors

        if self.mode == "rainbow":
            # Background tracks are white; foreground tracks are colored by
            # their y coordinate in frame 0.
            vector_colors[:, segm_mask <= 0, :] = 255
            foreground = segm_mask > 0
            if foreground.any():
                fg_ys = tracks[0, foreground, 1]
                norm = plt.Normalize(fg_ys.min(), fg_ys.max())
                for n in np.flatnonzero(foreground):
                    vector_colors[:, n] = (
                        np.array(self.color_map(norm(tracks[0, n, 1]))[:3]) * 255
                    )
            return vector_colors

        # Color by segmentation class: foreground takes the colormap's top
        # end, background its bottom end.
        color = np.zeros((segm_mask.shape[0], 3), dtype=np.float32)
        color[segm_mask > 0] = np.array(self.color_map(1.0)[:3]) * 255.0
        color[segm_mask <= 0] = np.array(self.color_map(0.0)[:3]) * 255.0
        return np.repeat(color[None], T, axis=0)

    def draw_tracks_on_video(
        self,
        video: torch.Tensor,
        tracks: torch.Tensor,
        visibility: torch.Tensor = None,
        segm_mask: torch.Tensor = None,
        gt_tracks=None,
        query_frame: int = 0,
        compensate_for_camera_motion=False,
        rigid_part=None,
    ):
        """Draw trajectories and points of batch element 0 onto its frames.

        Args:
            video: (B, T, 3, H, W) frames.
            tracks: (B, T, N, 2) pixel coordinates; a point with x == 0 or
                y == 0 is skipped.
            visibility: Optional (B, T, N[, 1]) values; > 0.5 means visible.
            segm_mask: Optional (N,) per-track labels (> 0 foreground).
            gt_tracks: Optional (B, T, N, 2) ground-truth coordinates.
            query_frame: Reference frame for coloring.
            compensate_for_camera_motion: Remove mean background motion from
                trajectories and draw only foreground tracks.
            rigid_part: Optional tensor of per-track cluster labels.

        Returns:
            uint8 tensor (1, T', 3, H, W) of rendered frames.
        """
        _, T, C, _, _ = video.shape
        _, _, N, D = tracks.shape
        assert D == 2
        assert C == 3
        frames = video[0].permute(0, 2, 3, 1).byte().detach().cpu().numpy()  # T, H, W, C
        tracks = tracks[0].long().detach().cpu().numpy()  # T, N, 2
        if gt_tracks is not None:
            # Take batch element 0 like `tracks`; per-frame slicing below and
            # _draw_gt_tracks expect a (T, N, 2) array.
            gt_tracks = self._to_numpy(gt_tracks[0])
        if segm_mask is not None:
            # Convert once so it can index NumPy arrays even if it arrived as a
            # (possibly CUDA) tensor.
            segm_mask = self._to_numpy(segm_mask)
        if visibility is not None:
            visibility = self._to_numpy(visibility[0])

        res_video = [frame.copy() for frame in frames]
        vector_colors = self._compute_vector_colors(tracks, segm_mask, query_frame)

        # Draw trajectory segments from the last `tracks_leave_trace` frames
        # (or the whole history when negative).
        if self.tracks_leave_trace != 0:
            for t in range(1, T):
                if self.tracks_leave_trace >= 0:
                    first_ind = max(0, t - self.tracks_leave_trace)
                else:
                    first_ind = 0
                curr_tracks = tracks[first_ind : t + 1]
                curr_colors = vector_colors[first_ind : t + 1]
                if compensate_for_camera_motion:
                    # Mean background displacement of each past frame relative
                    # to frame t, removed from every trajectory.
                    diff = (
                        tracks[first_ind : t + 1, segm_mask <= 0]
                        - tracks[t : t + 1, segm_mask <= 0]
                    ).mean(1)[:, None]
                    curr_tracks = curr_tracks - diff
                    curr_tracks = curr_tracks[:, segm_mask > 0]
                    curr_colors = curr_colors[:, segm_mask > 0]
                res_video[t] = self._draw_pred_tracks(res_video[t], curr_tracks, curr_colors)
                if gt_tracks is not None:
                    res_video[t] = self._draw_gt_tracks(
                        res_video[t], gt_tracks[first_ind : t + 1]
                    )

        label_colors = rigid_labels = None
        if rigid_part is not None:
            # One `jet` color per distinct cluster label, keyed by label value.
            labels = torch.unique(rigid_part)
            palette = plt.get_cmap("jet")(np.linspace(0, 1, len(labels)))[:, :3] * 255
            label_colors = dict(zip(labels.tolist(), palette))
            rigid_labels = rigid_part.reshape(-1).tolist()

        # Draw the current position of every track on each frame.
        radius = int(self.linewidth * 2)
        for t in range(T):
            canvas = res_video[t]
            for i in range(N):
                x, y = int(tracks[t, i, 0]), int(tracks[t, i, 1])
                if x == 0 or y == 0:
                    continue
                if compensate_for_camera_motion and not (segm_mask[i] > 0):
                    continue
                visible = visibility is None or bool(visibility[t, i] > 0.5)
                if rigid_part is not None:
                    color = label_colors[rigid_labels[i]]
                else:
                    color = vector_colors[t, i]
                # Filled disk when visible, 1-px outline otherwise.
                cv2.circle(
                    canvas,
                    (x, y),
                    radius,
                    color.tolist(),
                    thickness=-1 if visible else 1,
                )

        # Optionally hold the first frame at the start of the output.
        if self.show_first_frame > 0:
            res_video = [res_video[0]] * self.show_first_frame + res_video[1:]
        return torch.from_numpy(np.stack(res_video)).permute(0, 3, 1, 2)[None].byte()

    def _draw_pred_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        tracks: np.ndarray,  # T x N x 2
        vector_colors: np.ndarray,  # T x N x 3
        alpha: float = 0.5,
    ):
        """Draw line segments between consecutive positions of every track.

        ``alpha`` is accepted for interface compatibility but unused: when
        ``tracks_leave_trace > 0`` segment ``s`` is blended with weight
        ``(s / T) ** 2``, so older segments fade out. Segments whose start
        point has a zero coordinate are skipped.

        Returns:
            The image with segments drawn (may be a new array).
        """
        T, N, _ = tracks.shape
        fade = self.tracks_leave_trace > 0
        for s in range(T - 1):
            colors = vector_colors[s]
            original = rgb.copy() if fade else None
            for i in range(N):
                start = (int(tracks[s, i, 0]), int(tracks[s, i, 1]))
                end = (int(tracks[s + 1, i, 0]), int(tracks[s + 1, i, 1]))
                if start[0] != 0 and start[1] != 0:
                    cv2.line(
                        rgb,
                        start,
                        end,
                        colors[i].tolist(),
                        self.linewidth,
                        cv2.LINE_AA,
                    )
            if fade:
                weight = (s / T) ** 2
                rgb = cv2.addWeighted(rgb, weight, original, 1 - weight, 0)
        return rgb

    def _draw_gt_tracks(
        self,
        rgb: np.ndarray,  # H x W x 3
        gt_tracks: np.ndarray,  # T x N x 2
    ):
        """Draw a red ``X`` at every ground-truth point with positive x and y."""
        T, N, _ = gt_tracks.shape
        color = (211.0, 0.0, 0.0)
        half = self.linewidth * 3
        for t in range(T):
            for i in range(N):
                x, y = gt_tracks[t, i]
                if x > 0 and y > 0:
                    cx, cy = int(x), int(y)
                    cv2.line(
                        rgb,
                        (cx + half, cy + half),
                        (cx - half, cy - half),
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
                    cv2.line(
                        rgb,
                        (cx - half, cy + half),
                        (cx + half, cy - half),
                        color,
                        self.linewidth,
                        cv2.LINE_AA,
                    )
        return rgb