Spaces:
Sleeping
Sleeping
import os | |
import shutil | |
import argparse | |
import emage.mertic | |
from moviepy.tools import verbose_print | |
from omegaconf import OmegaConf | |
import random | |
import numpy as np | |
import json | |
import librosa | |
from datetime import datetime | |
import importlib | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from torch.optim import AdamW | |
from torch.utils.data import DataLoader | |
from torch.nn.parallel import DistributedDataParallel as DDP | |
# from torch.utils.tensorboard import SummaryWriter | |
import wandb | |
from diffusers.optimization import get_scheduler | |
from tqdm import tqdm | |
import smplx | |
from moviepy.editor import VideoFileClip, AudioFileClip, ImageSequenceClip | |
import igraph | |
import emage | |
import utils.rotation_conversions as rc | |
from create_graph import path_visualization, graph_pruning, get_motion_reps_tensor | |
def search_path(graph, audio_low_np, audio_high_np, top_k=1, loop_penalty=0.1, search_mode="both"):
    """Beam-search a path through the motion graph that best matches per-frame audio features.

    Args:
        graph: igraph graph whose vertices carry "previous", "motion_low" and
            "motion_high" attributes and whose edges carry "is_continue".
        audio_low_np, audio_high_np: per-frame audio features indexed by time step
            on axis 0 (presumably (T, C) and L2-normalised -- TODO confirm).
        top_k: beam width; also the maximum number of paths returned.
        loop_penalty: constant cost added each time a path re-enters a vertex it
            already contains.
        search_mode: "both", "high_level" or "low_level" -- which feature levels
            contribute to the per-frame cost.

    Returns:
        (path_list, is_continue_list): up to ``top_k`` vertex sequences sorted by
        ascending cost, and for each the "is_continue" flag of every edge taken.
        Both are empty when there is no audio or no start vertex. A path may be
        shorter than T if the beam reaches vertices with no outgoing edges.

    Raises:
        ValueError: if ``search_mode`` is not one of the supported values.
    """
    if search_mode not in ("both", "high_level", "low_level"):
        raise ValueError(f"Unknown search_mode: {search_mode!r}")
    T = audio_low_np.shape[0]  # total time steps
    if T == 0:
        return [], []

    def frame_cost(t, node):
        # 1 - dot-product similarity per used level (cosine distance when features are normalised).
        if search_mode == "high_level":
            return 1 - np.dot(audio_high_np[t], node["motion_high"].T)
        if search_mode == "low_level":
            return 1 - np.dot(audio_low_np[t], node["motion_low"].T)
        return 2 - (np.dot(audio_low_np[t], node["motion_low"].T) + np.dot(audio_high_np[t], node["motion_high"].T))

    # Start vertices are those with no predecessor clip frame.
    beam = [(frame_cost(0, v), [v]) for v in graph.vs if v["previous"] is None or v["previous"] == -1]
    beam.sort(key=lambda item: item[0])
    beam = beam[:top_k]

    for t in range(1, T):
        candidates = []
        for cost, seq in beam:
            # Build the visited set once per beam entry instead of scanning seq per neighbour.
            visited = {v.index for v in seq}
            for idx in graph.neighbors(seq[-1].index, mode="OUT"):
                neighbor = graph.vs[idx]
                step_cost = cost + loop_penalty if idx in visited else cost
                candidates.append((step_cost + frame_cost(t, neighbor), seq + [neighbor]))
        if not candidates:
            break  # no beam entry can be extended any further
        candidates.sort(key=lambda item: item[0])  # stable: ties keep generation order
        beam = candidates[:top_k]

    path_list = []
    is_continue_list = []
    for cost, seq in beam:
        path_list.append(seq)
        print("Cost: ", cost, "path", [node.index for node in seq])
        is_continue_list.append(
            [graph.es[graph.get_eid(a.index, b.index)]["is_continue"] for a, b in zip(seq, seq[1:])]
        )
    return path_list, is_continue_list
def search_path_dp(graph, audio_low_np, audio_high_np, loop_penalty=0.01, top_k=1, search_mode="both", continue_penalty=0.01):
    """Find a low-cost path through the motion graph with a per-vertex dynamic program.

    For every (time step, vertex) pair only the cheapest incoming path is kept,
    together with that path's per-vertex visit counts and its number of
    non-continuous edges. Because those penalties depend on the whole path, this
    is a greedy approximation rather than an exact shortest path.

    Args:
        graph: igraph graph; vertices carry "previous", "motion_low" and
            "motion_high", edges carry "is_continue".
        audio_low_np, audio_high_np: per-frame audio features indexed by time on
            axis 0 (presumably (T, C) and L2-normalised -- TODO confirm).
        loop_penalty: entering a vertex already visited k times on the path costs
            an extra ``loop_penalty * exp(k)``.
        top_k: accepted for signature parity with ``search_path``; only the single
            best path is returned.
        search_mode: "both", "high_level" or "low_level".
        continue_penalty: every step adds ``continue_penalty`` times the number of
            non-continuous edges taken so far (linear, not squared).

    Returns:
        ([path], [is_continue]) for the best path, or ([], []) when no vertex is
        reachable at the last time step (or there is no audio).

    Raises:
        ValueError: if ``search_mode`` is not one of the supported values.
    """
    if search_mode not in ("both", "high_level", "low_level"):
        raise ValueError(f"Unknown search_mode: {search_mode!r}")
    T = audio_low_np.shape[0]  # total time steps
    if T == 0:
        print("No valid path found.")
        return [], []

    def frame_cost(t, node):
        # 1 - dot-product similarity per used level (cosine distance when features are normalised).
        if search_mode == "high_level":
            return 1 - np.dot(audio_high_np[t], node["motion_high"].T)
        if search_mode == "low_level":
            return 1 - np.dot(audio_low_np[t], node["motion_low"].T)
        return 2 - (np.dot(audio_low_np[t], node["motion_low"].T) + np.dot(audio_high_np[t], node["motion_high"].T))

    # min_cost[t][v] = (cost, predecessor_index, non_continue_count)
    min_cost = [{} for _ in range(T)]
    # visited_nodes[t][v] = {vertex_index: visit_count} along the best path ending at v
    visited_nodes = [{} for _ in range(T)]
    for node in graph.vs:
        if node["previous"] is None or node["previous"] == -1:
            min_cost[0][node.index] = (frame_cost(0, node), None, 0)
            visited_nodes[0][node.index] = {node.index: 1}

    # The topology does not change during the search, so look up incoming edges once.
    incoming = {v.index: list(graph.es.select(_to=v.index)) for v in graph.vs}

    for t in range(1, T):
        prev_costs = min_cost[t - 1]
        prev_visits = visited_nodes[t - 1]
        if not prev_costs:
            break  # nothing reachable any more; later steps would stay empty too
        for node in graph.vs:
            node_index = node.index
            best_cost = float("inf")
            best_prev = None
            best_non_continue = 0
            step_cost = None  # depends only on (t, node): compute lazily, at most once
            for edge in incoming[node_index]:
                prev_index = edge.source
                if prev_index not in prev_costs:
                    continue
                prev_cost, _, prev_non_continue = prev_costs[prev_index]
                loop_time = prev_visits[prev_index].get(node_index, 0)
                if loop_time:
                    # Exponential punishment for re-entering an already visited vertex.
                    loop_cost = prev_cost + loop_penalty * np.exp(loop_time)
                else:
                    loop_cost = prev_cost
                non_continue = prev_non_continue if edge["is_continue"] else prev_non_continue + 1
                if step_cost is None:
                    step_cost = frame_cost(t, node)
                total_cost = loop_cost + step_cost + continue_penalty * non_continue
                if total_cost < best_cost:
                    best_cost = total_cost
                    best_prev = prev_index
                    best_non_continue = non_continue
            if best_prev is not None:
                new_visits = dict(prev_visits[best_prev])
                new_visits[node_index] = new_visits.get(node_index, 0) + 1
                min_cost[t][node_index] = (best_cost, best_prev, best_non_continue)
                visited_nodes[t][node_index] = new_visits

    final_min_cost = float("inf")
    final_node_index = None
    for node_index, (cost, _, _) in min_cost[T - 1].items():
        if cost < final_min_cost:
            final_min_cost = cost
            final_node_index = node_index
    if final_node_index is None:
        print("No valid path found.")
        return [], []

    # Backtrack through the stored predecessors, then restore chronological order.
    optimal_path_indices = []
    current = final_node_index
    for t in range(T - 1, -1, -1):
        optimal_path_indices.append(current)
        current = min_cost[t][current][1]
    optimal_path_indices.reverse()
    optimal_path = [graph.vs[idx] for idx in optimal_path_indices]

    is_continue = [
        graph.es[graph.get_eid(a.index, b.index)]["is_continue"]
        for a, b in zip(optimal_path, optimal_path[1:])
    ]
    print("Optimal Cost: ", final_min_cost, "Path: ", optimal_path_indices)
    return [optimal_path], [is_continue]
# from torch.cuda.amp import autocast, GradScaler | |
# from torch.nn.utils import clip_grad_norm_ | |
# # Initialize GradScaler | |
# scaler = GradScaler() | |
def train_val_fn(batch, model, device, mode="train", optimizer=None, lr_scheduler=None, max_grad_norm=1.0, **kwargs):
    """Run one training step or one gradient-free validation forward pass on ``batch``.

    Args:
        batch: mapping with tensors under "cached_rep15d", "cached_audio_low",
            "cached_audio_high", "bert_time_aligned" and "audio_tensor".
        model: model (optionally DDP-wrapped) whose output dict provides
            "audio_cls", "motion_cls", "high_level_loss" and a (loss, accuracy)
            pair under "low_level_loss".
        device: device the batch tensors are moved to.
        mode: "train" runs backward plus optimizer/scheduler steps; any other
            value runs the forward pass in eval mode without autograd.
        optimizer: required when ``mode == "train"``.
        lr_scheduler: stepped once per training call when given.
        max_grad_norm: currently unused -- gradient clipping is disabled.
        **kwargs: accepted for call-site compatibility (e.g. ``loss_fn_dict``); unused.

    Returns:
        dict with "low_cosine", "high_infonce", "loss", "low_acc" and "acc".
        Tensor values are detached so the dict does not keep the autograd graph alive.

    Raises:
        ValueError: if ``mode == "train"`` and no optimizer is given.
    """
    is_train = mode == "train"
    if is_train and optimizer is None:
        raise ValueError("optimizer is required when mode='train'")
    model.train(is_train)

    # Scoped grad mode: restores the caller's autograd state on exit instead of
    # leaving it globally disabled after a validation call.
    with torch.set_grad_enabled(is_train):
        if is_train:
            optimizer.zero_grad()
        cached_rep15d = batch["cached_rep15d"].to(device)
        cached_audio_low = batch["cached_audio_low"].to(device)
        cached_audio_high = batch["cached_audio_high"].to(device)
        bert_time_aligned = batch["bert_time_aligned"].to(device)
        # Text (BERT) features are appended to the high-level audio features on the last axis.
        cached_audio_high = torch.cat([cached_audio_high, bert_time_aligned], dim=-1)
        audio_tensor = batch["audio_tensor"].to(device)

        model_out = model(
            cached_rep15d=cached_rep15d,
            cached_audio_low=cached_audio_low,
            cached_audio_high=cached_audio_high,
            in_audio=audio_tensor,
        )
        high_loss = model_out["high_level_loss"]
        low_infonce, low_acc = model_out["low_level_loss"]
        loss = low_infonce + high_loss

        loss_dict = {
            "low_cosine": low_infonce,
            "high_infonce": high_loss,
            "loss": loss,
            "low_acc": low_acc,
            "acc": compute_average_precision(model_out["audio_cls"], model_out["motion_cls"]),
        }

        if is_train:
            loss.backward()
            optimizer.step()
            if lr_scheduler is not None:
                lr_scheduler.step()

    return {k: v.detach() if torch.is_tensor(v) else v for k, v in loss_dict.items()}
def _load_audio(audio_path, target_sr):
    # Load a waveform at its native rate and resample it exactly once to target_sr.
    waveform, native_sr = librosa.load(audio_path, sr=None)
    return librosa.resample(waveform, orig_sr=native_sr, target_sr=target_sr)


def _encode_graph_motion(actual_model, graph, device, window_size):
    # Encode every graph vertex's pose window-by-window; returns L2-normalised (low, high) rows in clip order.
    clips = {}
    for node in graph.vs:
        clips.setdefault(node["name"], []).append(node["axis_angle"].reshape(-1))
    low_all, high_all = [], []
    for frames in clips.values():
        motion = torch.from_numpy(np.stack(frames)).float().to(device).unsqueeze(0)  # (1, T, J*3)
        total = motion.shape[1]
        remain = total % window_size
        low_parts, high_parts = [], []
        for start in range(0, total - remain, window_size):
            feats = actual_model.get_motion_features(motion[:, start:start + window_size, :])
            high_parts.append(feats["motion_high_weight"].cpu().numpy()[0])
            low_parts.append(feats["motion_low"].cpu().numpy()[0])
        if remain > 0:
            # Re-encode the last window (clamped at 0 so clips shorter than one window
            # are encoded whole) and keep only the frames not yet covered.
            feats = actual_model.get_motion_features(motion[:, max(total - window_size, 0):, :])
            high_parts.append(feats["motion_high_weight"].cpu().numpy()[0][-remain:])
            low_parts.append(feats["motion_low"].cpu().numpy()[0][-remain:])
        high_all.append(np.concatenate(high_parts, axis=0))
        low_all.append(np.concatenate(low_parts, axis=0))
    low = np.concatenate(low_all, axis=0)
    high = np.concatenate(high_all, axis=0)
    low = low / np.linalg.norm(low, axis=1, keepdims=True)
    high = high / np.linalg.norm(high, axis=1, keepdims=True)
    return low, high


def _encode_audio_windows(actual_model, audio_tensor, window_size, target_length):
    # Encode a (1, samples) waveform window-by-window into per-frame normalised (low, high) features, at most target_length frames.
    total = audio_tensor.shape[1]
    remain = total % window_size
    low_parts, high_parts = [], []
    for start in range(0, total - remain, window_size):
        out = actual_model.get_audio_features(audio_tensor[:, start:start + window_size])
        low_parts.append(F.normalize(out["audio_low"], dim=2)[0].cpu().numpy())
        high_parts.append(F.normalize(out["audio_high_weight"], dim=2)[0].cpu().numpy())
    # Count produced *frames* (axis 0), not feature channels.
    gap = target_length - sum(part.shape[0] for part in low_parts)
    if remain > 0 and gap > 0:
        out = actual_model.get_audio_features(audio_tensor[:, max(total - window_size, 0):])
        low_parts.append(F.normalize(out["audio_low"], dim=2)[0][-gap:].cpu().numpy())
        high_parts.append(F.normalize(out["audio_high_weight"], dim=2)[0][-gap:].cpu().numpy())
    if not low_parts:
        raise ValueError("audio is empty; no features to search with")
    low = np.concatenate(low_parts, axis=0)[:target_length]
    high = np.concatenate(high_parts, axis=0)[:target_length]
    return low, high


def _smplx_joints(smplx_model, poses, device):
    # Run SMPL-X on (1, n, 165) axis-angle poses with zero shape/expression/root; return the first 55 joints as (n, 165).
    bs, n, _ = poses.shape
    flat = poses.reshape(bs * n, 55 * 3)

    def zeros(dim):
        return torch.zeros(bs * n, dim).to(device)

    joints = smplx_model(
        betas=zeros(300),
        transl=zeros(3),
        expression=zeros(100),
        jaw_pose=zeros(3),
        global_orient=zeros(3),
        body_pose=flat[:, 3:21 * 3 + 3],
        left_hand_pose=flat[:, 25 * 3:40 * 3],
        right_hand_pose=flat[:, 40 * 3:55 * 3],
        return_joints=True,
        leye_pose=zeros(3),
        reye_pose=zeros(3),
    )["joints"]
    return joints.detach().cpu().numpy().reshape(bs, n, 127 * 3)[0, :, :55 * 3]


def test_fn(model, device, smplx_model, iteration, fgd_fn, srgr_fn, bc_fn, l1div_fn, candidate_json_path, test_path, cfg, **kwargs):
    """Retrieve a motion for every test clip via graph search and score it against ground truth.

    Each entry of the JSON list at ``candidate_json_path`` supplies "motion_path"
    and "audio_path" stems (".npz" / ".wav" are appended). The audio is encoded
    window-by-window, ``search_path`` finds a path through the pruned motion graph,
    and the result is rendered and saved under ``test_path/retrieved_motions_{iteration}``.
    The saved top-1 motions are then scored with the given metric objects.

    Args:
        model: retrieval model, optionally DDP-wrapped.
        device: torch device for inference.
        smplx_model: SMPL-X layer used to turn poses into joint positions.
        iteration: training iteration, used to name the output directory.
        fgd_fn, srgr_fn, bc_fn, l1div_fn: metric objects (reset before scoring).
        candidate_json_path: JSON file listing the test pairs.
        test_path: directory receiving the retrieved motions.
        cfg: config providing ``data.pose_length`` and ``data.audio_sr``.
        **kwargs: ``pool_path`` optionally overrides the pickled motion-graph path.

    Returns:
        dict with "fgd_top1", "l1_top1", "bc_top1" and "srgr_top1".
    """
    pool_path = kwargs.get("pool_path", "./datasets/oliver_test/show-oliver-test.pkl")
    save_dir = os.path.join(test_path, f"retrieved_motions_{iteration}")
    os.makedirs(save_dir, exist_ok=True)
    actual_model = model.module if isinstance(model, DDP) else model
    actual_model.eval()
    with open(candidate_json_path, "r") as f:
        candidate_data = json.load(f)

    # no_grad scopes inference without leaving autograd disabled for the caller.
    with torch.no_grad():
        # NOTE: Read_Pickle unpickles the file -- only point pool_path at trusted data.
        graph = igraph.Graph.Read_Pickle(fname=pool_path)
        motion_low_all, motion_high_all = _encode_graph_motion(actual_model, graph, device, cfg.data.pose_length)
        if motion_low_all.shape[0] != len(graph.vs) or motion_high_all.shape[0] != len(graph.vs):
            raise RuntimeError("encoded motion features do not match the number of graph vertices")
        for node, low, high in zip(graph.vs, motion_low_all, motion_high_all):
            node["motion_high"] = high
            node["motion_low"] = low
        graph = graph_pruning(graph)

        audio_sr = cfg.data.audio_sr
        # Audio samples spanning one pose window (poses are at 30 fps).
        audio_window = int(audio_sr * (cfg.data.pose_length / 30))

        for idx, pair in enumerate(tqdm(candidate_data, desc="Testing")):
            target_length = np.load(pair["motion_path"] + ".npz", allow_pickle=True)["poses"].shape[0]
            audio_path = pair["audio_path"] + ".wav"
            audio_tensor = torch.from_numpy(_load_audio(audio_path, audio_sr)).float().to(device).unsqueeze(0)
            audio_low_all, audio_high_all = _encode_audio_windows(actual_model, audio_tensor, audio_window, target_length)
            path_list, is_continue_list = search_path(graph, audio_low_all, audio_high_all, top_k=1, search_mode="high_level")
            for counter, (path, is_continue) in enumerate(zip(path_list, is_continue_list)):
                stem = os.path.join(save_dir, f"audio_{idx}_retri_{counter}")
                res_motion = path_visualization(
                    graph, path, is_continue, stem + ".mp4", audio_path=audio_path, return_motion=True, verbose_continue=True
                )
                np.savez(stem + ".npz", motion=res_motion)

        fgd_fn.reset()
        l1div_fn.reset()
        bc_fn.reset()
        srgr_fn.reset()
        for idx, pair in enumerate(tqdm(candidate_data, desc="Evaluating")):
            gt_motion = np.load(pair["motion_path"] + ".npz", allow_pickle=True)["poses"]
            audio_waveform = _load_audio(pair["audio_path"] + ".wav", audio_sr)
            top1_motion = np.load(os.path.join(save_dir, f"audio_{idx}_retri_0.npz"), allow_pickle=True)["motion"]  # (T, 165)
            n = gt_motion.shape[0]
            if top1_motion.shape[0] != n:
                raise ValueError(f"retrieved motion {idx} has {top1_motion.shape[0]} frames, expected {n}")
            gt_tensor = torch.from_numpy(gt_motion).float().to(device).unsqueeze(0)
            top1_tensor = torch.from_numpy(top1_motion).float().to(device).unsqueeze(0)

            gt_joints = _smplx_joints(smplx_model, gt_tensor, device)
            top1_joints = _smplx_joints(smplx_model, top1_tensor, device)

            l1div_fn.run(top1_joints)
            onset_bt = bc_fn.load_audio(audio_waveform, t_start=None, without_file=True, sr_audio=audio_sr)
            beat_vel = bc_fn.load_pose(top1_joints, 0, n, pose_fps=30, without_file=True)
            bc_fn.calculate_align(onset_bt, beat_vel, 30)
            srgr_fn.run(gt_joints, top1_joints)

            # FGD gets 6D rotations truncated to a multiple of 32 frames.
            usable = n - n % 32
            gt_6d = rc.matrix_to_rotation_6d(rc.axis_angle_to_matrix(gt_tensor.reshape(1, n, 55, 3))).reshape(1, n, 55 * 6)
            top1_6d = rc.matrix_to_rotation_6d(rc.axis_angle_to_matrix(top1_tensor.reshape(1, n, 55, 3))).reshape(1, n, 55 * 6)
            fgd_fn.update(gt_6d[:, :usable], top1_6d[:, :usable])

        metrics = {
            "fgd_top1": fgd_fn.compute(),
            "l1_top1": l1div_fn.avg(),
            "bc_top1": bc_fn.avg(),
            "srgr_top1": srgr_fn.avg(),
        }

    print(f"Test Metrics at Iteration {iteration}:")
    for key, value in metrics.items():
        print(f"{key}: {value:.6f}")
    return metrics
def compute_average_precision(feature1, feature2):
    """Return the top-1 retrieval accuracy of ``feature1`` rows against ``feature2`` rows.

    Row i of ``feature1`` counts as correct when, after L2 normalisation, its most
    similar row in ``feature2`` is row i. The result is the mean over rows as a
    0-dim float tensor (despite the name, this is accuracy, not average precision).
    """
    anchors = F.normalize(feature1, dim=1)
    candidates = F.normalize(feature2, dim=1)
    predicted = (anchors @ candidates.t()).argmax(dim=1)
    expected = torch.arange(feature1.size(0), device=feature1.device)
    return predicted.eq(expected).float().mean()
class CosineSimilarityLoss(nn.Module):
    """Loss equal to one minus the mean cosine similarity, measured along dim 2."""

    def __init__(self):
        super().__init__()
        self.cosine_similarity = nn.CosineSimilarity(dim=2)

    def forward(self, output1, output2):
        """Return ``1 - mean(cos(output1, output2))`` with cosine taken over dim 2."""
        mean_similarity = self.cosine_similarity(output1, output2).mean()
        return 1 - mean_similarity
class InfoNCELossCross(nn.Module):
    """Batch-wise InfoNCE loss computed with cross-entropy over a cosine-similarity matrix."""

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature
        self.criterion = nn.CrossEntropyLoss()

    def forward(self, feature1, feature2):
        """
        Args:
            feature1: tensor of shape (batch_size, feature_dim)
            feature2: tensor of shape (batch_size, feature_dim)
        Row i of ``feature1`` and row i of ``feature2`` form the positive pair; every
        other row of ``feature2`` is a negative for row i of ``feature1``.
        """
        anchors = F.normalize(feature1, dim=1)
        candidates = F.normalize(feature2, dim=1)
        logits = torch.matmul(anchors, candidates.t()) / self.temperature
        # The matching candidate for anchor i sits at column i.
        targets = torch.arange(feature1.size(0), device=feature1.device)
        return self.criterion(logits, targets)
class LocalContrastiveLoss(nn.Module):
    """Frame-level contrastive loss between motion and audio sequences.

    For each frame t, frames in ``[t-4, t+4)`` of the other modality are positives
    and up to 12 frames on either side of that window are negatives. The loss is
    computed motion->audio and audio->motion and averaged over both directions and
    all frames. Accuracy is the fraction of (frame, sample) queries whose best
    match falls inside the positive window.
    """

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    @staticmethod
    def _one_direction(queries, keys, temperature, num_frames):
        """Return (summed per-frame loss, summed correct count) for queries -> keys."""
        loss_total = 0
        correct_total = 0
        for t in range(num_frames):
            pos_lo = max(0, t - 4)
            pos_hi = min(num_frames, t + 4)
            neg_left = keys[:, max(0, pos_lo - 4 * 3):pos_lo, :]
            neg_right = keys[:, pos_hi:min(num_frames, pos_hi + 4 * 3), :]
            positives = keys[:, pos_lo:pos_hi, :]
            # Positives first, so indices below num_pos identify a positive match.
            pool = torch.cat([positives, torch.cat([neg_left, neg_right], dim=1)], dim=1)
            scores = (torch.matmul(queries[:, t, :].unsqueeze(1), pool.transpose(1, 2)) / temperature).squeeze(1)
            num_pos = positives.size(1)
            per_sample = -scores[:, :num_pos].logsumexp(dim=1) + torch.logsumexp(scores, dim=1)
            loss_total += per_sample.mean()
            hits = (torch.argmax(scores, dim=1) < num_pos).float()
            correct_total += hits.sum()
        return loss_total, correct_total

    def forward(self, motion_feature, audio_feature, learned_temp=None):
        """Return (loss, accuracy) for (batch, T, C) motion and audio features."""
        temperature = self.temperature if learned_temp is None else learned_temp
        batch_size, T, _ = motion_feature.size()
        assert len(motion_feature.shape) == 3
        motion = F.normalize(motion_feature, dim=2)
        audio = F.normalize(audio_feature, dim=2)
        m2a_loss, m2a_correct = self._one_direction(motion, audio, temperature, T)
        a2m_loss, a2m_correct = self._one_direction(audio, motion, temperature, T)
        final_loss = (m2a_loss + a2m_loss) / (2 * T)
        total_correct = (m2a_correct + a2m_correct) / (2 * T * batch_size)
        return final_loss, total_correct
class InfoNCELoss(nn.Module):
    """Batch-wise InfoNCE loss on 2-D features, with an optional per-call temperature."""

    def __init__(self, temperature=0.1):
        super().__init__()
        self.temperature = temperature

    def forward(self, feature1, feature2, learned_temp=None):
        """Return the mean InfoNCE loss where row i of each input forms the positive pair.

        ``learned_temp``, when given, overrides the temperature set at construction.
        """
        assert len(feature1.shape) == 2
        temperature = self.temperature if learned_temp is None else learned_temp
        logits = torch.matmul(F.normalize(feature1, dim=1), F.normalize(feature2, dim=1).t()) / temperature
        positives = torch.diag(logits)
        # logsumexp keeps the softmax denominator numerically stable.
        log_denominators = torch.logsumexp(logits, dim=1)
        return -(positives - log_denominators).mean()
def main(cfg):
    """Train the audio/motion retrieval model with DistributedDataParallel.

    Expects a distributed launcher (e.g. torchrun) to provide the environment used
    by ``init_process_group``; ``LOCAL_RANK`` (default 0) selects the GPU. The
    process with local rank 0 logs to wandb, validates on up to 10 test batches every
    ``cfg.validation.val_loss_steps`` iterations, and writes checkpoints under
    ``cfg.output_dir/cfg.exp_name``. Training stops after ``cfg.solver.max_train_steps``
    iterations, matching the LR schedule length.
    """
    local_rank = int(os.environ.get("LOCAL_RANK", 0))
    is_main = local_rank == 0
    torch.cuda.set_device(local_rank)
    device = torch.device("cuda", local_rank)
    torch.distributed.init_process_group(backend="nccl")
    seed_everything(cfg.seed)
    experiment_ckpt_dir = os.path.join(cfg.output_dir, cfg.exp_name)

    smplx_model = smplx.create(
        "./emage/smplx_models/",
        model_type='smplx',
        gender='NEUTRAL_2020',
        use_face_contour=False,
        num_betas=300,
        num_expression_coeffs=100,
        ext='npz',
        use_pca=False,
    ).to(device).eval()

    model = init_class(cfg.model.name_pyfile, cfg.model.class_name, cfg).cuda()
    for param in model.parameters():
        param.requires_grad = True
    # Freeze the pretrained audio encoder (wav2vec2).
    for param in model.audio_encoder.parameters():
        param.requires_grad = False
    model.smplx_model = smplx_model
    model.get_motion_reps = get_motion_reps_tensor
    model.high_level_loss_fn = InfoNCELoss()
    model.low_level_loss_fn = LocalContrastiveLoss()
    model = DDP(
        model,
        device_ids=[local_rank],
        output_device=local_rank,
        find_unused_parameters=True,
    )

    if cfg.solver.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError as err:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes`"
            ) from err
        optimizer_cls = bnb.optim.AdamW8bit
        print("using 8 bit")
    else:
        optimizer_cls = torch.optim.AdamW
    optimizer = optimizer_cls(
        [p for p in model.parameters() if p.requires_grad],
        lr=cfg.solver.learning_rate,
        betas=(cfg.solver.adam_beta1, cfg.solver.adam_beta2),
        weight_decay=cfg.solver.adam_weight_decay,
        eps=cfg.solver.adam_epsilon,
    )
    lr_scheduler = get_scheduler(
        cfg.solver.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=cfg.solver.lr_warmup_steps * cfg.solver.gradient_accumulation_steps,
        num_training_steps=cfg.solver.max_train_steps * cfg.solver.gradient_accumulation_steps,
    )

    loss_fn_dict = {
        "loss_cosine": CosineSimilarityLoss().to(device),
        "loss_mse": nn.MSELoss().to(device),
        "loss_l1": nn.L1Loss().to(device),
        "loss_infonce": InfoNCELossCross().to(device),
    }
    # Metric objects for test_fn; the periodic test evaluation is currently disabled.
    fgd_fn = emage.mertic.FGD(download_path="./emage/")
    srgr_fn = emage.mertic.SRGR(threshold=0.3, joints=55, joint_dim=3)
    bc_fn = emage.mertic.BC(download_path="./emage/", sigma=0.5, order=7)
    l1div_fn = emage.mertic.L1div()

    train_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split='train')
    test_dataset = init_class(cfg.data.name_pyfile, cfg.data.class_name, cfg, split='test')
    train_sampler = torch.utils.data.distributed.DistributedSampler(train_dataset)
    train_loader = DataLoader(train_dataset, batch_size=cfg.data.train_bs, sampler=train_sampler, drop_last=True, num_workers=4)
    test_sampler = torch.utils.data.distributed.DistributedSampler(test_dataset)
    test_loader = DataLoader(test_dataset, batch_size=256, sampler=test_sampler, drop_last=False, num_workers=4)

    if is_main:
        run_time = datetime.now().strftime("%Y%m%d-%H%M")
        wandb.init(
            project=cfg.wandb_project,
            name=cfg.exp_name + "_" + run_time,
            entity=cfg.wandb_entity,
            dir=cfg.wandb_log_dir,
            config=OmegaConf.to_container(cfg),
        )

    def save_checkpoint(directory, step):
        # Persist model/optimizer/scheduler state so training can be resumed from `directory`.
        os.makedirs(directory, exist_ok=True)
        torch.save({
            'iteration': step,
            'model_state_dict': model.state_dict(),
            'optimizer_state_dict': optimizer.state_dict(),
            'lr_scheduler_state_dict': lr_scheduler.state_dict(),
        }, os.path.join(directory, "ckpt.pth"))

    max_steps = cfg.solver.max_train_steps
    num_epochs = max_steps // len(train_loader) + 1
    iteration = 0
    val_best = {}
    # To resume, load a ckpt.pth written by save_checkpoint here (model / optimizer /
    # lr_scheduler state dicts and "iteration").
    for epoch in range(num_epochs):
        if iteration >= max_steps:
            break
        train_sampler.set_epoch(epoch)
        for i, batch in enumerate(train_loader):
            if iteration >= max_steps:
                break  # the LR schedule ends here; stepping past it is unintended
            loss_dict = train_val_fn(
                batch, model, device, mode="train", optimizer=optimizer, lr_scheduler=lr_scheduler,
                loss_fn_dict=loss_fn_dict,
            )
            if is_main and iteration % cfg.log_period == 0:
                wandb.log({f"train/{key}": value for key, value in loss_dict.items()}, step=iteration)
                loss_message = ", ".join([f"{k}: {v:.6f}" for k, v in loss_dict.items()])
                print(f"Epoch {epoch} [{i}/{len(train_loader)}] - {loss_message}")

            if is_main and iteration % cfg.validation.val_loss_steps == 0:
                # Only this rank validates, so use the unwrapped module: a DDP forward
                # may issue collectives (e.g. buffer broadcast) other ranks never join.
                val_sums = {}
                val_batches = 0
                for val_batch in tqdm(test_loader):
                    val_dict = train_val_fn(val_batch, model.module, device, mode="val", loss_fn_dict=loss_fn_dict)
                    for k, v in val_dict.items():
                        val_sums[k] = val_sums.get(k, 0.0) + float(v)
                    val_batches += 1
                    if val_batches == 10:
                        break
                for k, total in val_sums.items():
                    v = total / val_batches
                    # Accuracies improve upward, losses downward.
                    higher_is_better = "acc" in k
                    best = val_best.get(k)
                    if best is None or (v > best["value"] if higher_is_better else v < best["value"]):
                        val_best[k] = {"value": v, "iteration": iteration}
                        if higher_is_better:
                            save_checkpoint(os.path.join(experiment_ckpt_dir, f"ckpt_{k}"), iteration)
                    print(f"Val [{iteration}] - {k}: {v:.6f} (best: {val_best[k]['value']:.6f} at {val_best[k]['iteration']})")
                    wandb.log({f"val/{k}": v}, step=iteration)

                save_checkpoint(os.path.join(experiment_ckpt_dir, f"checkpoint_{iteration}"), iteration)
                # Keep only the three most recent periodic checkpoints.
                checkpoints = sorted(
                    (
                        d for d in os.listdir(experiment_ckpt_dir)
                        if d.startswith("checkpoint_") and os.path.isdir(os.path.join(experiment_ckpt_dir, d))
                    ),
                    key=lambda name: int(name.split("_")[1]),
                )
                for stale in checkpoints[:-3]:
                    shutil.rmtree(os.path.join(experiment_ckpt_dir, stale))
            iteration += 1

    if is_main:
        wandb.finish()
    torch.distributed.destroy_process_group()
def init_class(module_name, class_name, config, **kwargs):
    """Dynamically import ``module_name`` and instantiate its ``class_name``.

    ``config`` is passed as the constructor's first positional argument and every
    extra keyword argument is forwarded unchanged.

    Returns:
        The newly constructed instance.
    """
    cls = getattr(importlib.import_module(module_name), class_name)
    return cls(config, **kwargs)
def seed_everything(seed):
    """Seed Python's, NumPy's and PyTorch's (CPU and every CUDA device) RNGs with ``seed``."""
    # Applied in a fixed order: random -> numpy -> torch CPU -> torch CUDA.
    for seeder in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seeder(seed)
def visualize_fn(test_path, *, results_dir=None, output_dir=None, **kwargs):
    """Render the top-1 retrieved video for each test audio, muxed with that audio.

    ``test_path`` is read as a JSON object. Its ``"video_candidates"`` entry is the
    list of candidate video paths. Every entry whose value is a JSON object is
    treated as one test item and must provide ``"audio_path"``. Other entries,
    such as the candidate list itself, are skipped. For the item at enumeration
    position ``idx``, the ranking is read from ``audio_{idx}_retri_top10.json``
    (key ``"top10_indices"``), and the first index selects the video. The muxed
    clip is written as ``audio_{idx}_retri_top1.mp4``.

    Args:
        test_path: Path of the test JSON file.
        results_dir: Directory holding the ``audio_{idx}_retri_top10.json`` files.
            Defaults to ``test_path`` (the original behaviour).
            NOTE(review): that default can only work if ``test_path`` is a
            directory, but it is also opened as a JSON file above. Pass
            ``results_dir`` explicitly.
        output_dir: Directory for the rendered ``.mp4`` files. Defaults to the
            current working directory.
        **kwargs: Accepted for call-site compatibility; unused.

    Returns:
        The selected top-1 video path of each processed item, in order.

    Raises:
        KeyError: if ``"video_candidates"``, ``"top10_indices"`` or an item's
            ``"audio_path"`` is missing.
    """
    # Parse the test JSON once; it provides both the items and the candidate list.
    with open(test_path, 'r') as f:
        test_json = json.load(f)
    video_list = test_json["video_candidates"]
    top10_dir = test_path if results_dir is None else results_dir

    selected_video_path_list = []
    # .items() yields (key, value) pairs; only the value is a per-audio record.
    for idx, (_, data) in enumerate(test_json.items()):
        if not isinstance(data, dict):
            # Not an audio record (e.g. the "video_candidates" list); idx still
            # advances so file numbering follows the JSON entry position.
            continue
        top10_indices_path = os.path.join(top10_dir, f"audio_{idx}_retri_top10.json")
        with open(top10_indices_path, 'r') as f:
            top10_indices = json.load(f)["top10_indices"]
        video_path = video_list[top10_indices[0]]
        selected_video_path_list.append(video_path)

        out_name = f"audio_{idx}_retri_top1.mp4"
        out_path = out_name if output_dir is None else os.path.join(output_dir, out_name)
        # Close every clip even if loading or writing fails so ffmpeg reader
        # processes are not leaked across iterations.
        video = VideoFileClip(video_path)
        try:
            audio = AudioFileClip(data["audio_path"])
            try:
                video.set_audio(audio).write_videofile(out_path)
            finally:
                audio.close()
        finally:
            video.close()
    return selected_video_path_list
def _parse_override_value(raw):
    """Convert a CLI override string into a Python literal when it is one.

    ``ast.literal_eval`` recognises numbers, booleans, ``None``, strings, lists,
    tuples and dicts without executing code. Anything else (e.g. ``adamw`` or
    ``./data``) is returned unchanged as a string.
    """
    import ast

    try:
        return ast.literal_eval(raw)
    except (ValueError, TypeError, SyntaxError):
        return raw


def _apply_overrides(config, overrides):
    """Apply ``key=value`` overrides to ``config`` in place.

    Dotted keys (``solver.max_train_steps=10``) address nested mapping nodes.
    Only keys that already exist may be overridden.

    Raises:
        ValueError: if an override lacks ``=`` or its key is not in ``config``.
    """
    for arg in overrides:
        # Split on the first '=' only so values may themselves contain '='.
        key, sep, raw_value = arg.partition('=')
        if not sep:
            raise ValueError(f"Override {arg!r} is not of the form key=value.")
        *parents, leaf = key.split('.')
        node = config
        for part in parents:
            if not OmegaConf.is_dict(node) or part not in node:
                node = None
                break
            node = node[part]
        if not OmegaConf.is_dict(node) or leaf not in node:
            raise ValueError(f"Key {key} not found in config.")
        node[leaf] = _parse_override_value(raw_value)


def _snapshot_sources(src_dir, dest_dir, exclude_dirs=()):
    """Copy every ``.py`` file under ``src_dir`` into ``dest_dir``, keeping relative paths.

    Directories listed in ``exclude_dirs`` are pruned from the walk. This stops
    the snapshot from re-copying earlier snapshots, or itself, when the output
    directory lives inside the source tree.
    """
    excluded = {os.path.abspath(d) for d in exclude_dirs}
    for root, dirs, files in os.walk(src_dir):
        # In-place pruning: os.walk (topdown) will not descend into removed dirs.
        dirs[:] = [d for d in dirs if os.path.abspath(os.path.join(root, d)) not in excluded]
        for file in files:
            if not file.endswith(".py"):
                continue
            full_file_path = os.path.join(root, file)
            relative_path = os.path.relpath(full_file_path, src_dir)
            dest_path = os.path.join(dest_dir, relative_path)
            os.makedirs(os.path.dirname(dest_path), exist_ok=True)
            shutil.copy(full_file_path, dest_path)


def prepare_all():
    """Parse CLI arguments, build the run config and snapshot the sources for this run.

    Command line:
        --config PATH   OmegaConf ``.yaml`` file (default ``./configs/train/stage2.yaml``).
                        Its file stem becomes ``config.exp_name``.
        --debug         Set ``wandb_project`` and ``exp_name`` to ``"debug"`` and
                        ``solver.max_train_steps`` to 4.
        key=value ...   Trailing overrides applied after the above. Values are
                        parsed as Python literals when possible, and dotted keys
                        address nested entries.

    Side effects:
        Sets the ``WANDB_API_KEY`` environment variable from ``config.wandb_key``.
        Creates ``<output_dir>/<exp_name>/sanity_check`` and saves the config
        there as ``<exp_name>.yaml``. Copies every ``.py`` file under this
        script's directory into it, excluding the output directories.

    Returns:
        The loaded and overridden config.

    Raises:
        ValueError: for a non-``.yaml`` config path, a malformed override, or an
            override key not present in the config.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--config", type=str, default="./configs/train/stage2.yaml")
    parser.add_argument("--debug", action="store_true", help="Enable debugging mode")
    parser.add_argument('overrides', nargs=argparse.REMAINDER)
    args = parser.parse_args()

    if not args.config.endswith(".yaml"):
        raise ValueError("Unsupported config file format. Only .yaml files are allowed.")
    config = OmegaConf.load(args.config)
    # Experiment name is the config file's stem, e.g. ".../stage2.yaml" -> "stage2".
    config.exp_name = os.path.splitext(os.path.basename(args.config))[0]

    if args.debug:
        config.wandb_project = "debug"
        config.exp_name = "debug"
        config.solver.max_train_steps = 4
    if args.overrides:
        _apply_overrides(config, args.overrides)

    os.environ["WANDB_API_KEY"] = config.wandb_key

    save_dir = os.path.join(config.output_dir, config.exp_name)
    sanity_check_dir = os.path.join(save_dir, 'sanity_check')
    os.makedirs(sanity_check_dir, exist_ok=True)
    config_path = os.path.join(sanity_check_dir, f'{config.exp_name}.yaml')
    with open(config_path, 'w') as f:
        OmegaConf.save(config, f)

    current_dir = os.path.dirname(os.path.abspath(__file__))
    _snapshot_sources(current_dir, sanity_check_dir, exclude_dirs=(config.output_dir, save_dir))
    return config
if __name__ == "__main__":
    # Build the run config (yaml + CLI overrides) and hand it to the training entry point.
    main(prepare_all())