import collections
import json
import math
import os
import re
import threading
from typing import List, Literal, Optional, Tuple, Union
from colorama import Fore, Style, init
init(autoreset=True)
import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import repeat
from PIL import Image
from tqdm.auto import tqdm
from seva.geometry import get_camera_dist, get_plucker_coordinates, to_hom_pose
from seva.sampling import (
Discretization,
EulerEDMSampler,
MultiviewCFG,
MultiviewTemporalCFG,
VanillaCFG,
)
from seva.utils import seed_everything
try:
    # Torch nightly/dev builds include "dev" in the version string.
version = torch.__version__
IS_TORCH_NIGHTLY = "dev" in version
if IS_TORCH_NIGHTLY:
torch._dynamo.config.cache_size_limit = 128 # type: ignore[assignment]
torch._dynamo.config.accumulated_cache_size_limit = 1024 # type: ignore[assignment]
torch._dynamo.config.force_parameter_static_shapes = False # type: ignore[assignment]
except Exception:
IS_TORCH_NIGHTLY = False
def pad_indices(
input_indices: List[int],
test_indices: List[int],
T: int,
padding_mode: Literal["first", "last", "none"] = "last",
):
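    """Pad input/test selection indices so that together they fill a window of length T.

    Returns (input_indices, test_indices, input_maps, test_maps), where the two
    `*_maps` arrays map each of the T window slots to a position in the (padded)
    input or test sequence, or -1 if the slot belongs to the other sequence.

    Illustrative example (the values follow directly from the logic below):
        pad_indices([0, 1], [2, 3], T=6)
        # input_indices -> [0, 1]
        # test_indices  -> [2, 3, 4, 5]  (slots 4, 5 padded with the last test frame)
        # input_maps    -> [ 0,  1, -1, -1, -1, -1]
        # test_maps     -> [-1, -1,  0,  1,  1,  1]
    """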
assert padding_mode in ["last", "none"], "`first` padding is not supported yet."
if padding_mode == "last":
padded_indices = [
i for i in range(T) if i not in (input_indices + test_indices)
]
else:
padded_indices = []
input_selects = list(range(len(input_indices)))
test_selects = list(range(len(test_indices)))
if max(input_indices) > max(test_indices):
# last elem from input
input_selects += [input_selects[-1]] * len(padded_indices)
input_indices = input_indices + padded_indices
sorted_inds = np.argsort(input_indices)
input_indices = [input_indices[ind] for ind in sorted_inds]
input_selects = [input_selects[ind] for ind in sorted_inds]
else:
# last elem from test
test_selects += [test_selects[-1]] * len(padded_indices)
test_indices = test_indices + padded_indices
sorted_inds = np.argsort(test_indices)
test_indices = [test_indices[ind] for ind in sorted_inds]
test_selects = [test_selects[ind] for ind in sorted_inds]
if padding_mode == "last":
input_maps = np.array([-1] * T)
test_maps = np.array([-1] * T)
else:
input_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
test_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
input_maps[input_indices] = input_selects
test_maps[test_indices] = test_selects
return input_indices, test_indices, input_maps, test_maps
def assemble(
input,
test,
input_maps,
test_maps,
):
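    """Scatter `input` and `test` frames into a single length-T sequence.

    `input_maps`/`test_maps` come from `pad_indices`: slot i of the output takes
    `input[input_maps[i]]` if `input_maps[i] != -1` and `test[test_maps[i]]`
    otherwise. The two maps must cover each slot exactly once.
    """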
T = len(input_maps)
assembled = torch.zeros_like(test[-1:]).repeat_interleave(T, dim=0)
assembled[input_maps != -1] = input[input_maps[input_maps != -1]]
assembled[test_maps != -1] = test[test_maps[test_maps != -1]]
assert np.logical_xor(input_maps != -1, test_maps != -1).all()
return assembled
def get_resizing_factor(
target_shape: Tuple[int, int], # H, W
current_shape: Tuple[int, int], # H, W
cover_target: bool = True,
# If True, the output shape will fully cover the target shape.
    # If False, the target shape will fully cover the output shape.
) -> float:
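    """Return the scalar factor by which `current_shape` should be resized so that it
    either fully covers (`cover_target=True`) or is fully contained in `target_shape`.

    Example consistent with the branches below: for a 576x576 target and a 720x1280
    (H, W) input, the factor is 576 / 720 = 0.8, i.e. the input becomes 576x1024 and
    covers the target.
    """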
r_bound = target_shape[1] / target_shape[0]
aspect_r = current_shape[1] / current_shape[0]
if r_bound >= 1.0:
if cover_target:
if aspect_r >= r_bound:
factor = min(target_shape) / min(current_shape)
elif aspect_r < 1.0:
factor = max(target_shape) / min(current_shape)
else:
factor = max(target_shape) / max(current_shape)
else:
if aspect_r >= r_bound:
factor = max(target_shape) / max(current_shape)
elif aspect_r < 1.0:
factor = min(target_shape) / max(current_shape)
else:
factor = min(target_shape) / min(current_shape)
else:
if cover_target:
if aspect_r <= r_bound:
factor = min(target_shape) / min(current_shape)
elif aspect_r > 1.0:
factor = max(target_shape) / min(current_shape)
else:
factor = max(target_shape) / max(current_shape)
else:
if aspect_r <= r_bound:
factor = max(target_shape) / max(current_shape)
elif aspect_r > 1.0:
factor = min(target_shape) / max(current_shape)
else:
factor = min(target_shape) / min(current_shape)
return factor
def get_unique_embedder_keys_from_conditioner(conditioner):
keys = [x.input_key for x in conditioner.embedders if x.input_key is not None]
keys = [item for sublist in keys for item in sublist] # Flatten list
return set(keys)
def get_wh_with_fixed_shortest_side(w, h, size):
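    """Rescale (w, h) so that the shortest side equals `size`, keeping the aspect ratio.

    E.g. get_wh_with_fixed_shortest_side(1280, 720, 576) -> (1024, 576).
    """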
    # If size is None or non-positive, return the original (w, h).
if size is None or size <= 0:
return w, h
if w < h:
new_w = size
new_h = int(size * h / w)
else:
new_h = size
new_w = int(size * w / h)
return new_w, new_h
def load_img_and_K(
image_path_or_size: Union[str, torch.Size],
size: Optional[Union[int, Tuple[int, int]]],
scale: float = 1.0,
center: Tuple[float, float] = (0.5, 0.5),
K: torch.Tensor | None = None,
size_stride: int = 1,
center_crop: bool = False,
image_as_tensor: bool = True,
context_rgb: np.ndarray | None = None,
device: str = "cuda",
):
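    """Load an image (or create a blank one from a `torch.Size`) and return it together
    with its intrinsics, resized and cropped to `size`.

    `K` may be given either normalized (principal point in [0, 1]) or in pixels; both
    cases are handled below and the returned intrinsics account for the resize/crop.
    If `image_as_tensor` is True, the image is a (1, 3, H, W) tensor in [-1, 1];
    otherwise it is a PIL image in [0, 255].
    """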
if isinstance(image_path_or_size, torch.Size):
image = Image.new("RGBA", image_path_or_size[::-1])
else:
image = Image.open(image_path_or_size).convert("RGBA")
w, h = image.size
if size is None:
size = (w, h)
image = np.array(image).astype(np.float32) / 255
if image.shape[-1] == 4:
rgb, alpha = image[:, :, :3], image[:, :, 3:]
if context_rgb is not None:
image = rgb * alpha + context_rgb * (1 - alpha)
else:
image = rgb * alpha + (1 - alpha)
image = image.transpose(2, 0, 1)
image = torch.from_numpy(image).to(dtype=torch.float32)
image = image.unsqueeze(0)
if isinstance(size, (tuple, list)):
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
W, H = size
else:
# => if size is int, we rescale the image to fit the shortest side to size
# => if size is None, no rescaling is applied
W, H = get_wh_with_fixed_shortest_side(w, h, size)
W, H = (
math.floor(W / size_stride + 0.5) * size_stride,
math.floor(H / size_stride + 0.5) * size_stride,
)
rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
image = torch.nn.functional.interpolate(
image, resize_size, mode="area", antialias=False
)
if scale < 1.0:
pw = math.ceil((W - resize_size[1]) * 0.5)
ph = math.ceil((H - resize_size[0]) * 0.5)
image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)
cy_center = int(center[1] * image.shape[-2])
cx_center = int(center[0] * image.shape[-1])
if center_crop:
side = min(H, W)
ct = max(0, cy_center - side // 2)
cl = max(0, cx_center - side // 2)
ct = min(ct, image.shape[-2] - side)
cl = min(cl, image.shape[-1] - side)
image = TF.crop(image, top=ct, left=cl, height=side, width=side)
else:
ct = max(0, cy_center - H // 2)
cl = max(0, cx_center - W // 2)
ct = min(ct, image.shape[-2] - H)
cl = min(cl, image.shape[-1] - W)
image = TF.crop(image, top=ct, left=cl, height=H, width=W)
if K is not None:
K = K.clone()
if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
K[:2] *= K.new_tensor([rw, rh])[:, None] # normalized K
else:
K[:2] *= K.new_tensor([rw / w, rh / h])[:, None] # unnormalized K
K[:2, 2] -= K.new_tensor([cl, ct])
if image_as_tensor:
# tensor of shape (1, 3, H, W) with values ranging from (-1, 1)
image = image.to(device) * 2.0 - 1.0
else:
# PIL Image with values ranging from (0, 255)
image = image.permute(0, 2, 3, 1).numpy()[0]
image = Image.fromarray((image * 255).astype(np.uint8))
return image, K
def transform_img_and_K(
image: torch.Tensor,
size: Union[int, Tuple[int, int]],
scale: float = 1.0,
center: Tuple[float, float] = (0.5, 0.5),
K: torch.Tensor | None = None,
size_stride: int = 1,
mode: str = "crop",
):
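    """Resize a batched image tensor to `size` and update the intrinsics accordingly.

    `mode` controls how aspect-ratio mismatch is resolved:
      - "crop":    rescale to cover the target, then crop around `center`.
      - "pad":     rescale to fit inside the target, then pad.
      - "stretch": rescale each side independently (no crop or pad).
    """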
assert mode in [
"crop",
"pad",
"stretch",
], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"
h, w = image.shape[-2:]
if isinstance(size, (tuple, list)):
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
W, H = size
else:
# => if size is int, we rescale the image to fit the shortest side to size
# => if size is None, no rescaling is applied
W, H = get_wh_with_fixed_shortest_side(w, h, size)
W, H = (
math.floor(W / size_stride + 0.5) * size_stride,
math.floor(H / size_stride + 0.5) * size_stride,
)
if mode == "stretch":
rh, rw = H, W
else:
rfs = get_resizing_factor(
(H, W),
(h, w),
cover_target=mode != "pad",
)
(rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]
rh, rw = int(rh / scale), int(rw / scale)
image = torch.nn.functional.interpolate(
image, (rh, rw), mode="area", antialias=False
)
cy_center = int(center[1] * image.shape[-2])
cx_center = int(center[0] * image.shape[-1])
if mode != "pad":
ct = max(0, cy_center - H // 2)
cl = max(0, cx_center - W // 2)
ct = min(ct, image.shape[-2] - H)
cl = min(cl, image.shape[-1] - W)
image = TF.crop(image, top=ct, left=cl, height=H, width=W)
pl, pt = 0, 0
else:
pt = max(0, H // 2 - cy_center)
pl = max(0, W // 2 - cx_center)
pb = max(0, H - pt - image.shape[-2])
pr = max(0, W - pl - image.shape[-1])
image = TF.pad(
image,
[pl, pt, pr, pb],
)
cl, ct = 0, 0
if K is not None:
K = K.clone()
# K[:, :2, 2] += K.new_tensor([pl, pt])
if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
K[:, :2] *= K.new_tensor([rw, rh])[None, :, None] # normalized K
else:
K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None] # unnormalized K
K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])
return image, K
lowvram_mode = False
def set_lowvram_mode(mode):
global lowvram_mode
lowvram_mode = mode
def load_model(model, device: str = "cuda"):
model.to(device)
def unload_model(model):
global lowvram_mode
if lowvram_mode:
model.cpu()
torch.cuda.empty_cache()
def infer_prior_stats(
T,
num_input_frames,
num_total_frames,
version_dict,
):
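    """Infer how many anchor (prior) frames the first pass should generate for 2-pass
    sampling, updating `version_dict["T"]` where needed.

    Rough worked example of the `interp` formula below, assuming T_second_pass = 21,
    num_total_frames = 80, and num_prior_frames_ratio = 1.0:
        ceil(80 / (21 - 2) * 1.0) + 1 = 5 + 1 = 6 prior frames.
    """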
options = version_dict["options"]
chunk_strategy = options.get("chunk_strategy", "nearest")
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
# get traj_prior_c2ws for 2-pass sampling
if chunk_strategy.startswith("interp"):
        # Start and end frames already take up two slots.
        # The +1 means we need X + 1 prior frames to bound X forward passes over all
        # test frames.
        # Increasing `num_prior_frames_ratio` helps when you observe sudden jumps in
        # the generated frames due to insufficient prior frames. It is most effective
        # for complicated trajectories when the `interp` strategy is used (usually the
        # semi-dense-view regime). Recommended range: [1.0 (default), 1.5].
if num_input_frames >= options.get("num_input_semi_dense", 9):
num_prior_frames = (
math.ceil(
num_total_frames
/ (T_second_pass - 2)
* options.get("num_prior_frames_ratio", 1.0)
)
+ 1
)
if num_prior_frames + num_input_frames < T_first_pass:
num_prior_frames = T_first_pass - num_input_frames
num_prior_frames = max(
num_prior_frames,
options.get("num_prior_frames", 0),
)
T_first_pass = num_prior_frames + num_input_frames
if "gt" in chunk_strategy:
T_second_pass = T_second_pass + num_input_frames
# Dynamically update context window length.
version_dict["T"] = [T_first_pass, T_second_pass]
else:
num_prior_frames = (
math.ceil(
num_total_frames
/ (
T_second_pass
- 2
- (num_input_frames if "gt" in chunk_strategy else 0)
)
* options.get("num_prior_frames_ratio", 1.0)
)
+ 1
)
if num_prior_frames + num_input_frames < T_first_pass:
num_prior_frames = T_first_pass - num_input_frames
num_prior_frames = max(
num_prior_frames,
options.get("num_prior_frames", 0),
)
else:
num_prior_frames = max(
T_first_pass - num_input_frames,
options.get("num_prior_frames", 0),
)
if num_input_frames >= options.get("num_input_semi_dense", 9):
T_first_pass = num_prior_frames + num_input_frames
# Dynamically update context window length.
version_dict["T"] = [T_first_pass, T_second_pass]
return num_prior_frames
def infer_prior_inds(
c2ws,
num_prior_frames,
input_frame_indices,
options,
):
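    """Pick `num_prior_frames` anchor view indices out of all non-input views: evenly
    spread for `interp*` strategies, greedily farthest-from-selected otherwise.
    """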
chunk_strategy = options.get("chunk_strategy", "nearest")
if chunk_strategy.startswith("interp"):
prior_frame_indices = np.array(
[i for i in range(c2ws.shape[0]) if i not in input_frame_indices]
)
prior_frame_indices = prior_frame_indices[
np.ceil(
np.linspace(
0, prior_frame_indices.shape[0] - 1, num_prior_frames, endpoint=True
)
).astype(int)
        ]  # using ceil here is safer for corner cases
else:
prior_frame_indices = []
while len(prior_frame_indices) < num_prior_frames:
closest_distance = np.abs(
np.arange(c2ws.shape[0])[None]
- np.concatenate(
[np.array(input_frame_indices), np.array(prior_frame_indices)]
)[:, None]
).min(0)
prior_frame_indices.append(np.argsort(closest_distance)[-1])
return np.sort(prior_frame_indices)
def compute_relative_inds(
source_inds,
target_inds,
):
assert len(source_inds) > 2
# compute relative indices of target_inds within source_inds
relative_inds = []
for ind in target_inds:
if ind in source_inds:
relative_ind = int(np.where(source_inds == ind)[0][0])
elif ind < source_inds[0]:
# extrapolate
relative_ind = -((source_inds[0] - ind) / (source_inds[1] - source_inds[0]))
elif ind > source_inds[-1]:
# extrapolate
relative_ind = len(source_inds) + (
(ind - source_inds[-1]) / (source_inds[-1] - source_inds[-2])
)
else:
# interpolate
lower_inds = source_inds[source_inds < ind]
upper_inds = source_inds[source_inds > ind]
if len(lower_inds) > 0 and len(upper_inds) > 0:
lower_ind = lower_inds[-1]
upper_ind = upper_inds[0]
relative_lower_ind = int(np.where(source_inds == lower_ind)[0][0])
relative_upper_ind = int(np.where(source_inds == upper_ind)[0][0])
relative_ind = relative_lower_ind + (ind - lower_ind) / (
upper_ind - lower_ind
) * (relative_upper_ind - relative_lower_ind)
            else:
                # Out of range; should be unreachable given the branches above, but
                # keep the append below well-defined.
                relative_ind = float("nan")  # placeholder
relative_inds.append(relative_ind)
return relative_inds
def find_nearest_source_inds(
source_c2ws,
target_c2ws,
nearest_num=1,
mode="translation",
):
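    """For each target camera, return the indices of the `nearest_num` closest source
    cameras under the given distance `mode` (e.g. "translation" or "rotation").

    Output shape: (num_targets, nearest_num).
    """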
dists = get_camera_dist(source_c2ws, target_c2ws, mode=mode).cpu().numpy()
sorted_inds = np.argsort(dists, axis=0).T
return sorted_inds[:, :nearest_num]
def chunk_input_and_test(
T,
input_c2ws,
test_c2ws,
input_ords, # orders
test_ords, # orders
options,
task: str = "img2img",
chunk_strategy: str = "gt",
gt_input_inds: list = [],
):
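    """Split M input views and N test views into chunks of length T for chunked sampling.

    Each chunk is a list of T string tokens: "!{idx}" for an input/anchor view,
    ">{idx}" for a test view to generate, and "NULL" for padding. Returns the chunks
    plus, per chunk, the input/test indices into the raw sequences and their slot
    positions (selections) within the length-T window.
    """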
M, N = input_c2ws.shape[0], test_c2ws.shape[0]
chunks = []
if chunk_strategy.startswith("gt"):
assert len(gt_input_inds) < T, (
f"Number of gt input frames {len(gt_input_inds)} should be "
f"less than {T} when `gt` chunking strategy is used."
)
assert (
list(range(M)) == gt_input_inds
), "All input_c2ws should be gt when `gt` chunking strategy is used."
num_test_seen = 0
while num_test_seen < N:
chunk = [f"!{i:03d}" for i in gt_input_inds]
if chunk_strategy != "gt" and num_test_seen > 0:
pseudo_num_ratio = options.get("pseudo_num_ratio", 0.33)
if (N - num_test_seen) >= math.floor(
(T - len(gt_input_inds)) * pseudo_num_ratio
):
pseudo_num = math.ceil((T - len(gt_input_inds)) * pseudo_num_ratio)
else:
pseudo_num = (T - len(gt_input_inds)) - (N - num_test_seen)
pseudo_num = min(pseudo_num, options.get("pseudo_num_max", 10000))
if "ltr" in chunk_strategy:
chunk.extend(
[
f"!{i + len(gt_input_inds):03d}"
for i in range(num_test_seen - pseudo_num, num_test_seen)
]
)
elif "nearest" in chunk_strategy:
source_inds = np.concatenate(
[
find_nearest_source_inds(
test_c2ws[:num_test_seen],
test_c2ws[num_test_seen:],
nearest_num=1, # pseudo_num,
mode="rotation",
),
find_nearest_source_inds(
test_c2ws[:num_test_seen],
test_c2ws[num_test_seen:],
nearest_num=1, # pseudo_num,
mode="translation",
),
],
axis=1,
)
                    ####### [HACK ALERT] keep running until pseudo num is stabilized ########
temp_pseudo_num = pseudo_num
while True:
nearest_source_inds = np.concatenate(
[
np.sort(
[
ind
for (ind, _) in collections.Counter(
[
item
for item in source_inds[
: T
- len(gt_input_inds)
- temp_pseudo_num
]
.flatten()
.tolist()
if item
!= (
num_test_seen - 1
) # exclude the last one here
]
).most_common(pseudo_num - 1)
],
).astype(int),
[num_test_seen - 1], # always keep the last one
]
)
if len(nearest_source_inds) >= temp_pseudo_num:
                            break  # stabilized
else:
temp_pseudo_num = len(nearest_source_inds)
pseudo_num = len(nearest_source_inds)
########################################################################
chunk.extend(
[f"!{i + len(gt_input_inds):03d}" for i in nearest_source_inds]
)
else:
raise NotImplementedError(
f"Chunking strategy {chunk_strategy} for the first pass is not implemented."
)
chunk.extend(
[
f">{i:03d}"
for i in range(
num_test_seen,
min(num_test_seen + T - len(gt_input_inds) - pseudo_num, N),
)
]
)
else:
chunk.extend(
[
f">{i:03d}"
for i in range(
num_test_seen,
min(num_test_seen + T - len(gt_input_inds), N),
)
]
)
num_test_seen += sum([1 for c in chunk if c.startswith(">")])
if len(chunk) < T:
chunk.extend(["NULL"] * (T - len(chunk)))
chunks.append(chunk)
elif chunk_strategy.startswith("nearest"):
input_imgs = np.array([f"!{i:03d}" for i in range(M)])
test_imgs = np.array([f">{i:03d}" for i in range(N)])
match = re.match(r"^nearest-(\d+)$", chunk_strategy)
if match:
nearest_num = int(match.group(1))
assert (
nearest_num < T
), f"Nearest number of {nearest_num} should be less than {T}."
source_inds = find_nearest_source_inds(
input_c2ws,
test_c2ws,
nearest_num=nearest_num,
mode="translation", # during the second pass, consider translation only is enough
)
for i in range(0, N, T - nearest_num):
nearest_source_inds = np.sort(
[
ind
for (ind, _) in collections.Counter(
source_inds[i : i + T - nearest_num].flatten().tolist()
).most_common(nearest_num)
]
)
chunk = (
input_imgs[nearest_source_inds].tolist()
+ test_imgs[i : i + T - nearest_num].tolist()
)
chunks.append(chunk + ["NULL"] * (T - len(chunk)))
else:
# do not always condition on gt cond frames
if "gt" not in chunk_strategy:
gt_input_inds = []
source_inds = find_nearest_source_inds(
input_c2ws,
test_c2ws,
nearest_num=1,
mode="translation", # during the second pass, consider translation only is enough
)[:, 0]
test_inds_per_input = {}
for test_idx, input_idx in enumerate(source_inds):
if input_idx not in test_inds_per_input:
test_inds_per_input[input_idx] = []
test_inds_per_input[input_idx].append(test_idx)
num_test_seen = 0
chunk = input_imgs[gt_input_inds].tolist()
candidate_input_inds = sorted(list(test_inds_per_input.keys()))
while num_test_seen < N:
input_idx = candidate_input_inds[0]
test_inds = test_inds_per_input[input_idx]
input_is_cond = input_idx in gt_input_inds
prefix_inds = [] if input_is_cond else [input_idx]
if len(chunk) == T - len(prefix_inds) or not candidate_input_inds:
if chunk:
chunk += ["NULL"] * (T - len(chunk))
chunks.append(chunk)
chunk = input_imgs[gt_input_inds].tolist()
if num_test_seen >= N:
break
continue
candidate_chunk = (
input_imgs[prefix_inds].tolist() + test_imgs[test_inds].tolist()
)
space_left = T - len(chunk)
if len(candidate_chunk) <= space_left:
chunk.extend(candidate_chunk)
num_test_seen += len(test_inds)
candidate_input_inds.pop(0)
else:
chunk.extend(candidate_chunk[:space_left])
num_input_idx = 0 if input_is_cond else 1
num_test_seen += space_left - num_input_idx
test_inds_per_input[input_idx] = test_inds[
space_left - num_input_idx :
]
if len(chunk) == T:
chunks.append(chunk)
chunk = input_imgs[gt_input_inds].tolist()
if chunk and chunk != input_imgs[gt_input_inds].tolist():
chunks.append(chunk + ["NULL"] * (T - len(chunk)))
elif chunk_strategy.startswith("interp"):
# `interp` chunk requires ordering info
assert input_ords is not None and test_ords is not None, (
"When using `interp` chunking strategy, ordering of input "
"and test frames should be provided."
)
        # If chunk_strategy is `interp*` and task is `img2trajvid*`, we do not use
        # input views since their order info within the target views is unknown.
if "img2trajvid" in task:
assert (
list(range(len(gt_input_inds))) == gt_input_inds
), "`img2trajvid` task should put `gt_input_inds` in start."
input_c2ws = input_c2ws[
[ind for ind in range(M) if ind not in gt_input_inds]
]
input_ords = [
input_ords[ind] for ind in range(M) if ind not in gt_input_inds
]
M = input_c2ws.shape[0]
input_ords = [0] + input_ords # this is a hack accounting for test views
# before the first input view
input_ords[-1] += 0.01 # this is a hack ensuring last test stop is included
# in the last forward when input_ords[-1] == test_ords[-1]
input_ords = np.array(input_ords)[:, None]
input_ords_ = np.concatenate([input_ords[1:], np.full((1, 1), np.inf)])
test_ords = np.array(test_ords)[None]
in_stop_ranges = np.logical_and(
np.repeat(input_ords, N, axis=1) <= np.repeat(test_ords, M + 1, axis=0),
np.repeat(input_ords_, N, axis=1) > np.repeat(test_ords, M + 1, axis=0),
        )  # (M + 1, N)
assert (in_stop_ranges.sum(1) <= T - 2).all(), (
"More anchor frames need to be sampled during the first pass to ensure "
f"#target frames during each forward in the second pass will not exceed {T - 2}."
)
if input_ords[1, 0] <= test_ords[0, 0]:
assert not in_stop_ranges[0].any()
if input_ords[-1, 0] >= test_ords[0, -1]:
assert not in_stop_ranges[-1].any()
gt_chunk = (
[f"!{i:03d}" for i in gt_input_inds] if "gt" in chunk_strategy else []
)
chunk = gt_chunk + []
# any test views before the first input views
if in_stop_ranges[0].any():
for j, in_range in enumerate(in_stop_ranges[0]):
if in_range:
chunk.append(f">{j:03d}")
in_stop_ranges = in_stop_ranges[1:]
i = 0
base_i = len(gt_input_inds) if "img2trajvid" in task else 0
chunk.append(f"!{i + base_i:03d}")
while i < len(in_stop_ranges):
in_stop_range = in_stop_ranges[i]
if not in_stop_range.any():
i += 1
continue
input_left = i + 1 < M
space_left = T - len(chunk)
if sum(in_stop_range) + input_left <= space_left:
for j, in_range in enumerate(in_stop_range):
if in_range:
chunk.append(f">{j:03d}")
i += 1
if input_left:
chunk.append(f"!{i + base_i:03d}")
else:
chunk += ["NULL"] * space_left
chunks.append(chunk)
chunk = gt_chunk + [f"!{i + base_i:03d}"]
if len(chunk) > 1:
chunk += ["NULL"] * (T - len(chunk))
chunks.append(chunk)
else:
raise NotImplementedError
(
input_inds_per_chunk,
input_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
) = (
[],
[],
[],
[],
)
for chunk in chunks:
input_inds = [
int(img.removeprefix("!")) for img in chunk if img.startswith("!")
]
input_sels = [chunk.index(img) for img in chunk if img.startswith("!")]
test_inds = [int(img.removeprefix(">")) for img in chunk if img.startswith(">")]
test_sels = [chunk.index(img) for img in chunk if img.startswith(">")]
input_inds_per_chunk.append(input_inds)
input_sels_per_chunk.append(input_sels)
test_inds_per_chunk.append(test_inds)
test_sels_per_chunk.append(test_sels)
if options.get("sampler_verbose", True):
def colorize(item):
if item.startswith("!"):
return f"{Fore.RED}{item}{Style.RESET_ALL}" # Red for items starting with '!'
elif item.startswith(">"):
return f"{Fore.GREEN}{item}{Style.RESET_ALL}" # Green for items starting with '>'
return item # Default color if neither '!' nor '>'
print("\nchunks:")
for chunk in chunks:
print(", ".join(colorize(item) for item in chunk))
return (
chunks,
input_inds_per_chunk, # ordering of input in raw sequence
input_sels_per_chunk, # ordering of input in one-forward sequence of length T
test_inds_per_chunk, # ordering of test in raw sequence
        test_sels_per_chunk,  # ordering of test in one-forward sequence of length T
)
def is_k_in_dict(d, k):
return any(map(lambda x: x.startswith(k), d.keys()))
def get_k_from_dict(d, k):
media_d = {}
for key, value in d.items():
if key == k:
return value
if key.startswith(k):
media = key.split("/")[-1]
if media == "raw":
return value
media_d[media] = value
if len(media_d) == 0:
return torch.tensor([])
assert (
len(media_d) == 1
), f"multiple media found in {d} for key {k}: {media_d.keys()}"
    return next(iter(media_d.values()))  # exactly one entry, per the assert above
def update_kv_for_dict(d, k, v):
for key in d.keys():
if key.startswith(k):
d[key] = v
return d
def extend_dict(ds, d):
for key in d.keys():
if key in ds:
ds[key] = torch.cat([ds[key], d[key]], 0)
else:
ds[key] = d[key]
return ds
def replace_or_include_input_for_dict(
samples,
test_indices,
imgs,
c2w,
K,
):
samples_new = {}
for sample, value in samples.items():
if "rgb" in sample:
imgs[test_indices] = (
value[test_indices] if value.shape[0] == imgs.shape[0] else value
).to(device=imgs.device, dtype=imgs.dtype)
samples_new[sample] = imgs
elif "c2w" in sample:
c2w[test_indices] = (
value[test_indices] if value.shape[0] == c2w.shape[0] else value
).to(device=c2w.device, dtype=c2w.dtype)
samples_new[sample] = c2w
elif "intrinsics" in sample:
K[test_indices] = (
value[test_indices] if value.shape[0] == K.shape[0] else value
).to(device=K.device, dtype=K.dtype)
samples_new[sample] = K
else:
samples_new[sample] = value
return samples_new
def decode_output(
samples,
T,
indices=None,
):
    # Decode the model output into a dict if it is not one already.
if isinstance(samples, dict):
# model with postprocessor and outputs dict
for sample, value in samples.items():
if isinstance(value, torch.Tensor):
value = value.detach().cpu()
elif isinstance(value, np.ndarray):
value = torch.from_numpy(value)
else:
value = torch.tensor(value)
if indices is not None and value.shape[0] == T:
value = value[indices]
samples[sample] = value
else:
# model without postprocessor and outputs tensor (rgb)
samples = samples.detach().cpu()
if indices is not None and samples.shape[0] == T:
samples = samples[indices]
samples = {"samples-rgb/image": samples}
return samples
def save_output(
samples,
save_path,
video_save_fps=2,
):
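    """Save a dict of outputs under `save_path`.

    Keys are interpreted as "{name}/{media_type}" with media_type in {"image",
    "video", "raw"}, defaulting to "video" when no "/" is present: images are
    written both as an mp4 and as per-frame PNGs, videos as an mp4, and raw tensors
    via `torch.save`. Image/video values are expected as (N, 3, H, W) in [-1, 1].
    """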
os.makedirs(save_path, exist_ok=True)
for sample in samples:
media_type = "video"
if "/" in sample:
sample_, media_type = sample.split("/")
else:
sample_ = sample
value = samples[sample]
if isinstance(value, torch.Tensor):
value = value.detach().cpu()
elif isinstance(value, np.ndarray):
value = torch.from_numpy(value)
else:
value = torch.tensor(value)
if media_type == "image":
value = (value.permute(0, 2, 3, 1) + 1) / 2.0
value = (value * 255).clamp(0, 255).to(torch.uint8)
iio.imwrite(
os.path.join(save_path, f"{sample_}.mp4")
if sample_
else f"{save_path}.mp4",
value,
fps=video_save_fps,
macro_block_size=1,
ffmpeg_log_level="error",
)
os.makedirs(os.path.join(save_path, sample_), exist_ok=True)
for i, s in enumerate(value):
iio.imwrite(
os.path.join(save_path, sample_, f"{i:03d}.png"),
s,
)
elif media_type == "video":
value = (value.permute(0, 2, 3, 1) + 1) / 2.0
value = (value * 255).clamp(0, 255).to(torch.uint8)
iio.imwrite(
os.path.join(save_path, f"{sample_}.mp4"),
value,
fps=video_save_fps,
macro_block_size=1,
ffmpeg_log_level="error",
)
elif media_type == "raw":
torch.save(
value,
os.path.join(save_path, f"{sample_}.pt"),
)
else:
pass
def create_transforms_simple(save_path, img_paths, img_whs, c2ws, Ks):
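    """Write a Nerfstudio-style `transforms.json` for the given images, sizes,
    camera-to-world matrices, and intrinsics."""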
import os.path as osp
out_frames = []
for img_path, img_wh, c2w, K in zip(img_paths, img_whs, c2ws, Ks):
out_frame = {
"fl_x": K[0][0].item(),
"fl_y": K[1][1].item(),
"cx": K[0][2].item(),
"cy": K[1][2].item(),
"w": img_wh[0].item(),
"h": img_wh[1].item(),
"file_path": f"./{osp.relpath(img_path, start=save_path)}"
if img_path is not None
else None,
"transform_matrix": c2w.tolist(),
}
out_frames.append(out_frame)
out = {
# "camera_model": "PINHOLE",
"orientation_override": "none",
"frames": out_frames,
}
with open(osp.join(save_path, "transforms.json"), "w") as of:
json.dump(out, of, indent=5)
def create_samplers(
guider_types: int | list[int],
discretization: Discretization,
num_frames: list[int] | None,
num_steps: int,
cfg_min: float = 1.0,
device: str | torch.device = "cuda",
abort_event: threading.Event | None = None,
):
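    """Build one EulerEDMSampler per requested guider type.

    Guider types: 0 -> VanillaCFG, 1 -> MultiviewCFG (takes `cfg_min`),
    2 -> MultiviewTemporalCFG (takes `num_frames[i]` and `cfg_min`).
    """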
guider_mapping = {
0: VanillaCFG,
1: MultiviewCFG,
2: MultiviewTemporalCFG,
}
samplers = []
if not isinstance(guider_types, (list, tuple)):
guider_types = [guider_types]
for i, guider_type in enumerate(guider_types):
if guider_type not in guider_mapping:
raise ValueError(
f"Invalid guider type {guider_type}. Must be one of {list(guider_mapping.keys())}"
)
guider_args = ()
if guider_type > 0:
guider_args += (cfg_min,)
if guider_type == 2:
assert num_frames is not None
guider_args = (num_frames[i], cfg_min)
guider = guider_mapping[guider_type](*guider_args)
sampler = EulerEDMSampler(
abort_event=abort_event,
discretization=discretization,
guider=guider,
num_steps=num_steps,
s_churn=0.0,
s_tmin=0.0,
s_tmax=999.0,
s_noise=1.0,
verbose=True,
device=device,
)
samplers.append(sampler)
return samplers
def get_value_dict(
curr_imgs,
curr_input_frame_indices,
curr_c2ws,
curr_Ks,
curr_input_camera_indices,
all_c2ws,
camera_scale,
):
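    """Build the conditioning dict for one forward pass: conditioning frames and masks,
    recentered and rescaled camera poses, intrinsics, and Plucker coordinates.

    Camera centers are shifted to the (outlier-robust) mean of `all_c2ws` and scaled
    so that the first camera sits at roughly distance `camera_scale` from the origin.
    """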
assert sorted(curr_input_camera_indices) == sorted(
range(len(curr_input_camera_indices))
)
    # F = 8 is the spatial downsampling factor of the latent space.
    H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8
value_dict = {}
value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs)
value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool)
value_dict["cond_frames_mask"][curr_input_frame_indices] = True
value_dict["cond_aug"] = 0.0
c2w = to_hom_pose(curr_c2ws.float())
w2c = torch.linalg.inv(c2w)
# camera centering
ref_c2ws = all_c2ws
camera_dist_2med = torch.norm(
ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values,
dim=-1,
)
valid_mask = camera_dist_2med <= torch.clamp(
torch.quantile(camera_dist_2med, 0.97) * 10,
max=1e6,
)
c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True)
w2c = torch.linalg.inv(c2w)
# camera normalization
camera_dists = c2w[:, :3, 3].clone()
translation_scaling_factor = (
camera_scale
if torch.isclose(
torch.norm(camera_dists[0]),
torch.zeros(1),
atol=1e-5,
).any()
else (camera_scale / torch.norm(camera_dists[0]))
)
w2c[:, :3, 3] *= translation_scaling_factor
c2w[:, :3, 3] *= translation_scaling_factor
value_dict["plucker_coordinate"] = get_plucker_coordinates(
extrinsics_src=w2c[0],
extrinsics=w2c,
intrinsics=curr_Ks.float().clone(),
target_size=(H // F, W // F),
)
value_dict["c2w"] = c2w
value_dict["K"] = curr_Ks
value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool)
value_dict["camera_mask"][curr_input_camera_indices] = True
return value_dict
def do_sample(
model,
ae,
conditioner,
denoiser,
sampler,
value_dict,
H,
W,
C,
F,
T,
cfg,
encoding_t=1,
decoding_t=1,
verbose=True,
global_pbar=None,
**_,
):
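    """Run one denoising pass: encode the conditioning frames, assemble the conditional
    and unconditional inputs, sample latents with `sampler`, and decode them back to
    images. Returns the decoded samples, or None if the sampler yields no latents
    (e.g. when aborted).
    """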
imgs = value_dict["cond_frames"].to("cuda")
input_masks = value_dict["cond_frames_mask"].to("cuda")
pluckers = value_dict["plucker_coordinate"].to("cuda")
num_samples = [1, T]
with torch.inference_mode(), torch.autocast("cuda"):
load_model(ae)
load_model(conditioner)
latents = torch.nn.functional.pad(
ae.encode(imgs[input_masks], encoding_t), (0, 0, 0, 0, 0, 1), value=1.0
)
c_crossattn = repeat(conditioner(imgs[input_masks]).mean(0), "d -> n 1 d", n=T)
uc_crossattn = torch.zeros_like(c_crossattn)
c_replace = latents.new_zeros(T, *latents.shape[1:])
c_replace[input_masks] = latents
uc_replace = torch.zeros_like(c_replace)
c_concat = torch.cat(
[
repeat(
input_masks,
"n -> n 1 h w",
h=pluckers.shape[2],
w=pluckers.shape[3],
),
pluckers,
],
1,
)
uc_concat = torch.cat(
[pluckers.new_zeros(T, 1, *pluckers.shape[-2:]), pluckers], 1
)
c_dense_vector = pluckers
uc_dense_vector = c_dense_vector
c = {
"crossattn": c_crossattn,
"replace": c_replace,
"concat": c_concat,
"dense_vector": c_dense_vector,
}
uc = {
"crossattn": uc_crossattn,
"replace": uc_replace,
"concat": uc_concat,
"dense_vector": uc_dense_vector,
}
unload_model(ae)
unload_model(conditioner)
additional_model_inputs = {"num_frames": T}
additional_sampler_inputs = {
"c2w": value_dict["c2w"].to("cuda"),
"K": value_dict["K"].to("cuda"),
"input_frame_mask": value_dict["cond_frames_mask"].to("cuda"),
}
if global_pbar is not None:
additional_sampler_inputs["global_pbar"] = global_pbar
shape = (math.prod(num_samples), C, H // F, W // F)
randn = torch.randn(shape).to("cuda")
load_model(model)
samples_z = sampler(
lambda input, sigma, c: denoiser(
model,
input,
sigma,
c,
**additional_model_inputs,
),
randn,
scale=cfg,
cond=c,
uc=uc,
verbose=verbose,
**additional_sampler_inputs,
)
if samples_z is None:
return
unload_model(model)
load_model(ae)
samples = ae.decode(samples_z, decoding_t)
unload_model(ae)
return samples
def run_one_scene(
task,
version_dict,
model,
ae,
conditioner,
denoiser,
image_cond,
camera_cond,
save_path,
use_traj_prior,
traj_prior_Ks,
traj_prior_c2ws,
seed=23,
gradio=False,
abort_event=None,
first_pass_pbar=None,
second_pass_pbar=None,
):
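    """Generate novel views for one scene, saving outputs under `save_path`.

    This is a generator: with 2-pass sampling it first yields the path of the
    first-pass video (if saved) and then the path of the final `samples-rgb.mp4`;
    with single-pass sampling it yields only the final video path.
    """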
H, W, T, C, F, options = (
version_dict["H"],
version_dict["W"],
version_dict["T"],
version_dict["C"],
version_dict["f"],
version_dict["options"],
)
if isinstance(image_cond, str):
image_cond = {"img": [image_cond]}
imgs, img_size = [], None
for i, (img, K) in enumerate(zip(image_cond["img"], camera_cond["K"])):
if isinstance(img, str) or img is None:
img, K = load_img_and_K(img or img_size, None, K=K, device="cpu") # type: ignore
img_size = img.shape[-2:]
if options.get("L_short", -1) == -1:
img, K = transform_img_and_K(
img,
(W, H),
K=K[None],
mode=(
options.get("transform_input", "crop")
if i in image_cond["input_indices"]
else options.get("transform_target", "crop")
),
scale=(
1.0
if i in image_cond["input_indices"]
else options.get("transform_scale", 1.0)
),
)
else:
downsample = 3
                assert options["L_short"] % (F * 2**downsample) == 0, (
                    "Short side of the image should be divisible by "
                    f"F * 2**{downsample} = {F * 2**downsample}."
                )
img, K = transform_img_and_K(
img,
options["L_short"],
K=K[None],
size_stride=F * 2**downsample,
mode=(
options.get("transform_input", "crop")
if i in image_cond["input_indices"]
else options.get("transform_target", "crop")
),
scale=(
1.0
if i in image_cond["input_indices"]
else options.get("transform_scale", 1.0)
),
)
version_dict["W"] = W = img.shape[-1]
version_dict["H"] = H = img.shape[-2]
K = K[0]
K[0] /= W
K[1] /= H
camera_cond["K"][i] = K
elif isinstance(img, np.ndarray):
img_size = torch.Size(img.shape[:2])
img = torch.as_tensor(img).permute(2, 0, 1)
img = img.unsqueeze(0)
img = img / 255.0 * 2.0 - 1.0
if not gradio:
img, K = transform_img_and_K(img, (W, H), K=K[None])
assert K is not None
K = K[0]
K[0] /= W
K[1] /= H
camera_cond["K"][i] = K
        else:
            raise TypeError(
                f"Variable `img` got {type(img)} type, which is not supported."
            )
imgs.append(img)
imgs = torch.cat(imgs, dim=0)
if traj_prior_Ks is not None:
assert img_size is not None
for i, prior_k in enumerate(traj_prior_Ks):
img, prior_k = load_img_and_K(img_size, None, K=prior_k, device="cpu") # type: ignore
img, prior_k = transform_img_and_K(
img,
(W, H),
K=prior_k[None],
mode=options.get(
"transform_target", "crop"
), # mode for prior is always same as target
scale=options.get(
"transform_scale", 1.0
), # scale for prior is always same as target
)
prior_k = prior_k[0]
prior_k[0] /= W
prior_k[1] /= H
traj_prior_Ks[i] = prior_k
options["num_frames"] = T
torch.cuda.empty_cache()
seed_everything(seed)
# Get Data
input_indices = image_cond["input_indices"]
input_imgs = imgs[input_indices]
input_c2ws = camera_cond["c2w"][input_indices]
input_Ks = camera_cond["K"][input_indices]
test_indices = [i for i in range(len(imgs)) if i not in input_indices]
test_imgs = imgs[test_indices]
test_c2ws = camera_cond["c2w"][test_indices]
test_Ks = camera_cond["K"][test_indices]
if options.get("save_input", True):
save_output(
{"/image": input_imgs},
save_path=os.path.join(save_path, "input"),
video_save_fps=2,
)
if not use_traj_prior:
chunk_strategy = options.get("chunk_strategy", "gt")
(
_,
input_inds_per_chunk,
input_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
) = chunk_input_and_test(
T,
input_c2ws,
test_c2ws,
input_indices,
test_indices,
options=options,
task=task,
chunk_strategy=chunk_strategy,
gt_input_inds=list(range(input_c2ws.shape[0])),
)
print(
f"One pass - chunking with `{chunk_strategy}` strategy: total "
f"{len(input_inds_per_chunk)} forward(s) ..."
)
all_samples = {}
all_test_inds = []
for i, (
chunk_input_inds,
chunk_input_sels,
chunk_test_inds,
chunk_test_sels,
) in tqdm(
enumerate(
zip(
input_inds_per_chunk,
input_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
)
),
total=len(input_inds_per_chunk),
leave=False,
):
(
curr_input_sels,
curr_test_sels,
curr_input_maps,
curr_test_maps,
) = pad_indices(
chunk_input_sels,
chunk_test_sels,
T=T,
padding_mode=options.get("t_padding_mode", "last"),
)
curr_imgs, curr_c2ws, curr_Ks = [
assemble(
input=x[chunk_input_inds],
test=y[chunk_test_inds],
input_maps=curr_input_maps,
test_maps=curr_test_maps,
)
for x, y in zip(
[
torch.cat(
[
input_imgs,
get_k_from_dict(all_samples, "samples-rgb").to(
input_imgs.device
),
],
dim=0,
),
torch.cat([input_c2ws, test_c2ws[all_test_inds]], dim=0),
torch.cat([input_Ks, test_Ks[all_test_inds]], dim=0),
                    ],  # procedurally append generated prior views to the input views
[test_imgs, test_c2ws, test_Ks],
)
]
value_dict = get_value_dict(
curr_imgs.to("cuda"),
curr_input_sels
+ [
sel
for (ind, sel) in zip(
np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
curr_test_sels,
)
if test_indices[ind] in image_cond["input_indices"]
],
curr_c2ws,
curr_Ks,
curr_input_sels
+ [
sel
for (ind, sel) in zip(
np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
curr_test_sels,
)
if test_indices[ind] in camera_cond["input_indices"]
],
all_c2ws=camera_cond["c2w"],
camera_scale=options.get("camera_scale", 2.0),
)
samplers = create_samplers(
options["guider_types"],
denoiser.discretization,
[len(curr_imgs)],
options["num_steps"],
options["cfg_min"],
abort_event=abort_event,
)
assert len(samplers) == 1
samples = do_sample(
model,
ae,
conditioner,
denoiser,
samplers[0],
value_dict,
H,
W,
C,
F,
T=len(curr_imgs),
cfg=(
options["cfg"][0]
if isinstance(options["cfg"], (list, tuple))
else options["cfg"]
),
**{k: options[k] for k in options if k not in ["cfg", "T"]},
)
samples = decode_output(
samples, len(curr_imgs), chunk_test_sels
) # decode into dict
if options.get("save_first_pass", False):
save_output(
replace_or_include_input_for_dict(
samples,
chunk_test_sels,
curr_imgs,
curr_c2ws,
curr_Ks,
),
save_path=os.path.join(save_path, "first-pass", f"forward_{i}"),
video_save_fps=2,
)
extend_dict(all_samples, samples)
all_test_inds.extend(chunk_test_inds)
else:
        assert traj_prior_c2ws is not None, (
            "`traj_prior_c2ws` should be set when using 2-pass sampling. One "
            "potential reason is that the number of input frames is larger than "
            "T. Set `num_prior_frames` manually to override the inferred stats."
        )
traj_prior_c2ws = torch.as_tensor(
traj_prior_c2ws,
device=input_c2ws.device,
dtype=input_c2ws.dtype,
)
if traj_prior_Ks is None:
traj_prior_Ks = test_Ks[:1].repeat_interleave(
traj_prior_c2ws.shape[0], dim=0
)
traj_prior_imgs = imgs.new_zeros(traj_prior_c2ws.shape[0], *imgs.shape[1:])
# ---------------------------------- first pass ----------------------------------
T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
chunk_strategy_first_pass = options.get(
"chunk_strategy_first_pass", "gt-nearest"
)
(
_,
input_inds_per_chunk,
input_sels_per_chunk,
prior_inds_per_chunk,
prior_sels_per_chunk,
) = chunk_input_and_test(
T_first_pass,
input_c2ws,
traj_prior_c2ws,
input_indices,
image_cond["prior_indices"],
options=options,
task=task,
chunk_strategy=chunk_strategy_first_pass,
gt_input_inds=list(range(input_c2ws.shape[0])),
)
print(
f"Two passes (first) - chunking with `{chunk_strategy_first_pass}` strategy: total "
f"{len(input_inds_per_chunk)} forward(s) ..."
)
all_samples = {}
all_prior_inds = []
for i, (
chunk_input_inds,
chunk_input_sels,
chunk_prior_inds,
chunk_prior_sels,
) in tqdm(
enumerate(
zip(
input_inds_per_chunk,
input_sels_per_chunk,
prior_inds_per_chunk,
prior_sels_per_chunk,
)
),
total=len(input_inds_per_chunk),
leave=False,
):
(
curr_input_sels,
curr_prior_sels,
curr_input_maps,
curr_prior_maps,
) = pad_indices(
chunk_input_sels,
chunk_prior_sels,
T=T_first_pass,
padding_mode=options.get("t_padding_mode", "last"),
)
curr_imgs, curr_c2ws, curr_Ks = [
assemble(
input=x[chunk_input_inds],
test=y[chunk_prior_inds],
input_maps=curr_input_maps,
test_maps=curr_prior_maps,
)
for x, y in zip(
[
torch.cat(
[
input_imgs,
get_k_from_dict(all_samples, "samples-rgb").to(
input_imgs.device
),
],
dim=0,
),
torch.cat([input_c2ws, traj_prior_c2ws[all_prior_inds]], dim=0),
torch.cat([input_Ks, traj_prior_Ks[all_prior_inds]], dim=0),
                    ],  # procedurally append generated prior views to the input views
[
traj_prior_imgs,
traj_prior_c2ws,
traj_prior_Ks,
],
)
]
value_dict = get_value_dict(
curr_imgs.to("cuda"),
curr_input_sels,
curr_c2ws,
curr_Ks,
list(range(T_first_pass)),
all_c2ws=camera_cond["c2w"],
camera_scale=options.get("camera_scale", 2.0),
)
samplers = create_samplers(
options["guider_types"],
denoiser.discretization,
[T_first_pass, T_second_pass],
options["num_steps"],
options["cfg_min"],
abort_event=abort_event,
)
samples = do_sample(
model,
ae,
conditioner,
denoiser,
(
samplers[1]
if len(samplers) > 1
and options.get("ltr_first_pass", False)
and chunk_strategy_first_pass != "gt"
and i > 0
else samplers[0]
),
value_dict,
H,
W,
C,
F,
cfg=(
options["cfg"][0]
if isinstance(options["cfg"], (list, tuple))
else options["cfg"]
),
T=T_first_pass,
global_pbar=first_pass_pbar,
**{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
)
if samples is None:
return
samples = decode_output(
samples, T_first_pass, chunk_prior_sels
) # decode into dict
extend_dict(all_samples, samples)
all_prior_inds.extend(chunk_prior_inds)
if options.get("save_first_pass", True):
save_output(
all_samples,
save_path=os.path.join(save_path, "first-pass"),
video_save_fps=5,
)
video_path_0 = os.path.join(save_path, "first-pass", "samples-rgb.mp4")
yield video_path_0
# ---------------------------------- second pass ----------------------------------
prior_indices = image_cond["prior_indices"]
        assert (
            prior_indices is not None
        ), "`prior_indices` needs to be set when using 2-pass sampling."
prior_argsort = np.argsort(input_indices + prior_indices).tolist()
prior_indices = np.array(input_indices + prior_indices)[prior_argsort].tolist()
gt_input_inds = [prior_argsort.index(i) for i in range(input_c2ws.shape[0])]
traj_prior_imgs = torch.cat(
[input_imgs, get_k_from_dict(all_samples, "samples-rgb")], dim=0
)[prior_argsort]
traj_prior_c2ws = torch.cat([input_c2ws, traj_prior_c2ws], dim=0)[prior_argsort]
traj_prior_Ks = torch.cat([input_Ks, traj_prior_Ks], dim=0)[prior_argsort]
update_kv_for_dict(all_samples, "samples-rgb", traj_prior_imgs)
update_kv_for_dict(all_samples, "samples-c2ws", traj_prior_c2ws)
update_kv_for_dict(all_samples, "samples-intrinsics", traj_prior_Ks)
chunk_strategy = options.get("chunk_strategy", "nearest")
(
_,
prior_inds_per_chunk,
prior_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
) = chunk_input_and_test(
T_second_pass,
traj_prior_c2ws,
test_c2ws,
prior_indices,
test_indices,
options=options,
task=task,
chunk_strategy=chunk_strategy,
gt_input_inds=gt_input_inds,
)
print(
f"Two passes (second) - chunking with `{chunk_strategy}` strategy: total "
f"{len(prior_inds_per_chunk)} forward(s) ..."
)
all_samples = {}
all_test_inds = []
for i, (
chunk_prior_inds,
chunk_prior_sels,
chunk_test_inds,
chunk_test_sels,
) in tqdm(
enumerate(
zip(
prior_inds_per_chunk,
prior_sels_per_chunk,
test_inds_per_chunk,
test_sels_per_chunk,
)
),
total=len(prior_inds_per_chunk),
leave=False,
):
(
curr_prior_sels,
curr_test_sels,
curr_prior_maps,
curr_test_maps,
) = pad_indices(
chunk_prior_sels,
chunk_test_sels,
T=T_second_pass,
padding_mode="last",
)
curr_imgs, curr_c2ws, curr_Ks = [
assemble(
input=x[chunk_prior_inds],
test=y[chunk_test_inds],
input_maps=curr_prior_maps,
test_maps=curr_test_maps,
)
for x, y in zip(
[
traj_prior_imgs,
traj_prior_c2ws,
traj_prior_Ks,
],
[test_imgs, test_c2ws, test_Ks],
)
]
value_dict = get_value_dict(
curr_imgs.to("cuda"),
curr_prior_sels,
curr_c2ws,
curr_Ks,
list(range(T_second_pass)),
all_c2ws=camera_cond["c2w"],
camera_scale=options.get("camera_scale", 2.0),
)
samples = do_sample(
model,
ae,
conditioner,
denoiser,
samplers[1] if len(samplers) > 1 else samplers[0],
value_dict,
H,
W,
C,
F,
T=T_second_pass,
cfg=(
options["cfg"][1]
if isinstance(options["cfg"], (list, tuple))
and len(options["cfg"]) > 1
else options["cfg"]
),
global_pbar=second_pass_pbar,
**{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
)
if samples is None:
return
samples = decode_output(
samples, T_second_pass, chunk_test_sels
) # decode into dict
if options.get("save_second_pass", False):
save_output(
replace_or_include_input_for_dict(
samples,
chunk_test_sels,
curr_imgs,
curr_c2ws,
curr_Ks,
),
save_path=os.path.join(save_path, "second-pass", f"forward_{i}"),
video_save_fps=2,
)
extend_dict(all_samples, samples)
all_test_inds.extend(chunk_test_inds)
all_samples = {
key: value[np.argsort(all_test_inds)] for key, value in all_samples.items()
}
save_output(
replace_or_include_input_for_dict(
all_samples,
test_indices,
imgs.clone(),
camera_cond["c2w"].clone(),
camera_cond["K"].clone(),
)
if options.get("replace_or_include_input", False)
else all_samples,
save_path=save_path,
video_save_fps=options.get("video_save_fps", 2),
)
video_path_1 = os.path.join(save_path, "samples-rgb.mp4")
yield video_path_1