import collections
import json
import math
import os
import re
import threading
from typing import List, Literal, Optional, Tuple, Union

import gradio as gr
from colorama import Fore, Style, init

init(autoreset=True)

import imageio.v3 as iio
import numpy as np
import torch
import torch.nn.functional as F
import torchvision.transforms.functional as TF
from einops import repeat
from PIL import Image
from tqdm.auto import tqdm

from seva.geometry import get_camera_dist, get_plucker_coordinates, to_hom_pose
from seva.sampling import (
    EulerEDMSampler,
    MultiviewCFG,
    MultiviewTemporalCFG,
    VanillaCFG,
)
from seva.utils import seed_everything

try:
    # Nightly (dev) builds of PyTorch include "dev" in the version string.
    version = torch.__version__
    IS_TORCH_NIGHTLY = "dev" in version
    if IS_TORCH_NIGHTLY:
        torch._dynamo.config.cache_size_limit = 128  # type: ignore[assignment]
        torch._dynamo.config.accumulated_cache_size_limit = 1024  # type: ignore[assignment]
        torch._dynamo.config.force_parameter_static_shapes = False  # type: ignore[assignment]
except Exception:
    IS_TORCH_NIGHTLY = False


def pad_indices(
    input_indices: List[int],
    test_indices: List[int],
    T: int,
    padding_mode: Literal["first", "last", "none"] = "last",
):
    assert padding_mode in ["last", "none"], "`first` padding is not supported yet."
    if padding_mode == "last":
        padded_indices = [
            i for i in range(T) if i not in (input_indices + test_indices)
        ]
    else:
        padded_indices = []
    input_selects = list(range(len(input_indices)))
    test_selects = list(range(len(test_indices)))
    if max(input_indices) > max(test_indices):
        # pad with the last element from input
        input_selects += [input_selects[-1]] * len(padded_indices)
        input_indices = input_indices + padded_indices
        sorted_inds = np.argsort(input_indices)
        input_indices = [input_indices[ind] for ind in sorted_inds]
        input_selects = [input_selects[ind] for ind in sorted_inds]
    else:
        # pad with the last element from test
        test_selects += [test_selects[-1]] * len(padded_indices)
        test_indices = test_indices + padded_indices
        sorted_inds = np.argsort(test_indices)
        test_indices = [test_indices[ind] for ind in sorted_inds]
        test_selects = [test_selects[ind] for ind in sorted_inds]

    if padding_mode == "last":
        input_maps = np.array([-1] * T)
        test_maps = np.array([-1] * T)
    else:
        input_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
        test_maps = np.array([-1] * (len(input_indices) + len(test_indices)))
    input_maps[input_indices] = input_selects
    test_maps[test_indices] = test_selects
    return input_indices, test_indices, input_maps, test_maps
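
# Illustrative sketch of `pad_indices` (toy values, not taken from the repo):
# with T=5, input_indices=[0], test_indices=[1, 2] and padding_mode="last",
# the leftover slots [3, 4] are filled by repeating the last test element:
#     input_indices -> [0]           input_maps -> [0, -1, -1, -1, -1]
#     test_indices  -> [1, 2, 3, 4]  test_maps  -> [-1, 0, 1, 1, 1]
# so every position in the length-T window maps back to either an input or a
# test frame, which is the invariant `assemble` below relies on.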


def assemble(
    input,
    test,
    input_maps,
    test_maps,
):
    T = len(input_maps)
    assembled = torch.zeros_like(test[-1:]).repeat_interleave(T, dim=0)
    assembled[input_maps != -1] = input[input_maps[input_maps != -1]]
    assembled[test_maps != -1] = test[test_maps[test_maps != -1]]
    assert np.logical_xor(input_maps != -1, test_maps != -1).all()
    return assembled


def get_resizing_factor(
    target_shape: Tuple[int, int],  # H, W
    current_shape: Tuple[int, int],  # H, W
    cover_target: bool = True,
    # If True, the output shape will fully cover the target shape.
    # If False, the target shape will fully cover the output shape.
) -> float:
    r_bound = target_shape[1] / target_shape[0]
    aspect_r = current_shape[1] / current_shape[0]
    if r_bound >= 1.0:
        if cover_target:
            if aspect_r >= r_bound:
                factor = min(target_shape) / min(current_shape)
            elif aspect_r < 1.0:
                factor = max(target_shape) / min(current_shape)
            else:
                factor = max(target_shape) / max(current_shape)
        else:
            if aspect_r >= r_bound:
                factor = max(target_shape) / max(current_shape)
            elif aspect_r < 1.0:
                factor = min(target_shape) / max(current_shape)
            else:
                factor = min(target_shape) / min(current_shape)
    else:
        if cover_target:
            if aspect_r <= r_bound:
                factor = min(target_shape) / min(current_shape)
            elif aspect_r > 1.0:
                factor = max(target_shape) / min(current_shape)
            else:
                factor = max(target_shape) / max(current_shape)
        else:
            if aspect_r <= r_bound:
                factor = max(target_shape) / max(current_shape)
            elif aspect_r > 1.0:
                factor = min(target_shape) / max(current_shape)
            else:
                factor = min(target_shape) / min(current_shape)
    return factor
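
# Worked example (toy numbers): with target (H, W) = (576, 1024) and current
# (H, W) = (720, 1280), the aspect ratios match, so with cover_target=True the
# factor is min(target) / min(current) = 576 / 720 = 0.8 and the resized image
# (576, 1024) exactly covers the target area.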


def get_unique_embedder_keys_from_conditioner(conditioner):
    keys = [x.input_key for x in conditioner.embedders if x.input_key is not None]
    keys = [item for sublist in keys for item in sublist]  # Flatten list
    return set(keys)


def get_wh_with_fixed_shortest_side(w, h, size):
    # if size is None or non-positive, return the original (w, h)
    if size is None or size <= 0:
        return w, h
    if w < h:
        new_w = size
        new_h = int(size * h / w)
    else:
        new_h = size
        new_w = int(size * w / h)
    return new_w, new_h
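
# Example: get_wh_with_fixed_shortest_side(1280, 720, 576) returns (1024, 576):
# the shortest side (h) becomes 576 and w is scaled to preserve aspect ratio.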


def load_img_and_K(
    image_path_or_size: Union[str, torch.Size],
    size: Optional[Union[int, Tuple[int, int]]],
    scale: float = 1.0,
    center: Tuple[float, float] = (0.5, 0.5),
    K: torch.Tensor | None = None,
    size_stride: int = 1,
    center_crop: bool = False,
    image_as_tensor: bool = True,
    context_rgb: np.ndarray | None = None,
    device: str = "cuda",
):
    if isinstance(image_path_or_size, torch.Size):
        image = Image.new("RGBA", image_path_or_size[::-1])
    else:
        image = Image.open(image_path_or_size).convert("RGBA")

    w, h = image.size
    if size is None:
        size = (w, h)

    image = np.array(image).astype(np.float32) / 255
    if image.shape[-1] == 4:
        rgb, alpha = image[:, :, :3], image[:, :, 3:]
        if context_rgb is not None:
            image = rgb * alpha + context_rgb * (1 - alpha)
        else:
            image = rgb * alpha + (1 - alpha)
    image = image.transpose(2, 0, 1)
    image = torch.from_numpy(image).to(dtype=torch.float32)
    image = image.unsqueeze(0)

    if isinstance(size, (tuple, list)):
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
        W, H = size
    else:
        # => if size is int, we rescale the image to fit the shortest side to size
        # => if size is None, no rescaling is applied
        W, H = get_wh_with_fixed_shortest_side(w, h, size)
    W, H = (
        math.floor(W / size_stride + 0.5) * size_stride,
        math.floor(H / size_stride + 0.5) * size_stride,
    )

    rfs = get_resizing_factor((math.floor(H * scale), math.floor(W * scale)), (h, w))
    resize_size = rh, rw = [int(np.ceil(rfs * s)) for s in (h, w)]
    image = torch.nn.functional.interpolate(
        image, resize_size, mode="area", antialias=False
    )
    if scale < 1.0:
        pw = math.ceil((W - resize_size[1]) * 0.5)
        ph = math.ceil((H - resize_size[0]) * 0.5)
        image = F.pad(image, (pw, pw, ph, ph), "constant", 1.0)

    cy_center = int(center[1] * image.shape[-2])
    cx_center = int(center[0] * image.shape[-1])
    if center_crop:
        side = min(H, W)
        ct = max(0, cy_center - side // 2)
        cl = max(0, cx_center - side // 2)
        ct = min(ct, image.shape[-2] - side)
        cl = min(cl, image.shape[-1] - side)
        image = TF.crop(image, top=ct, left=cl, height=side, width=side)
    else:
        ct = max(0, cy_center - H // 2)
        cl = max(0, cx_center - W // 2)
        ct = min(ct, image.shape[-2] - H)
        cl = min(cl, image.shape[-1] - W)
        image = TF.crop(image, top=ct, left=cl, height=H, width=W)

    if K is not None:
        K = K.clone()
        if torch.all(K[:2, -1] >= 0) and torch.all(K[:2, -1] <= 1):
            K[:2] *= K.new_tensor([rw, rh])[:, None]  # normalized K
        else:
            K[:2] *= K.new_tensor([rw / w, rh / h])[:, None]  # unnormalized K
        K[:2, 2] -= K.new_tensor([cl, ct])

    if image_as_tensor:
        # tensor of shape (1, 3, H, W) with values in [-1, 1]
        image = image.to(device) * 2.0 - 1.0
    else:
        # PIL Image with values in [0, 255]
        image = image.permute(0, 2, 3, 1).numpy()[0]
        image = Image.fromarray((image * 255).astype(np.uint8))
    return image, K
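
# Hedged usage sketch (the path and `K_3x3` intrinsics are illustrative
# assumptions):
#   image, K = load_img_and_K("assets/example.png", size=576, K=K_3x3,
#                             device="cpu")
# resizes the image so its shortest side becomes 576, crops the (H, W) window
# around `center`, rescales/offsets K accordingly, and returns the image as a
# (1, 3, H, W) tensor in [-1, 1] (or a PIL image if image_as_tensor=False).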


def transform_img_and_K(
    image: torch.Tensor,
    size: Union[int, Tuple[int, int]],
    scale: float = 1.0,
    center: Tuple[float, float] = (0.5, 0.5),
    K: torch.Tensor | None = None,
    size_stride: int = 1,
    mode: str = "crop",
):
    assert mode in [
        "crop",
        "pad",
        "stretch",
    ], f"mode should be one of ['crop', 'pad', 'stretch'], got {mode}"

    h, w = image.shape[-2:]
    if isinstance(size, (tuple, list)):
        # => if size is a tuple or list, we first rescale to fully cover the `size`
        # area and then crop the `size` area from the rescaled image
        W, H = size
    else:
        # => if size is int, we rescale the image to fit the shortest side to size
        # => if size is None, no rescaling is applied
        W, H = get_wh_with_fixed_shortest_side(w, h, size)
    W, H = (
        math.floor(W / size_stride + 0.5) * size_stride,
        math.floor(H / size_stride + 0.5) * size_stride,
    )

    if mode == "stretch":
        rh, rw = H, W
    else:
        rfs = get_resizing_factor(
            (H, W),
            (h, w),
            cover_target=mode != "pad",
        )
        (rh, rw) = [int(np.ceil(rfs * s)) for s in (h, w)]

    rh, rw = int(rh / scale), int(rw / scale)
    image = torch.nn.functional.interpolate(
        image, (rh, rw), mode="area", antialias=False
    )

    cy_center = int(center[1] * image.shape[-2])
    cx_center = int(center[0] * image.shape[-1])
    if mode != "pad":
        ct = max(0, cy_center - H // 2)
        cl = max(0, cx_center - W // 2)
        ct = min(ct, image.shape[-2] - H)
        cl = min(cl, image.shape[-1] - W)
        image = TF.crop(image, top=ct, left=cl, height=H, width=W)
        pl, pt = 0, 0
    else:
        pt = max(0, H // 2 - cy_center)
        pl = max(0, W // 2 - cx_center)
        pb = max(0, H - pt - image.shape[-2])
        pr = max(0, W - pl - image.shape[-1])
        image = TF.pad(
            image,
            [pl, pt, pr, pb],
        )
        cl, ct = 0, 0

    if K is not None:
        K = K.clone()
        # K[:, :2, 2] += K.new_tensor([pl, pt])
        if torch.all(K[:, :2, -1] >= 0) and torch.all(K[:, :2, -1] <= 1):
            K[:, :2] *= K.new_tensor([rw, rh])[None, :, None]  # normalized K
        else:
            K[:, :2] *= K.new_tensor([rw / w, rh / h])[None, :, None]  # unnormalized K
        K[:, :2, 2] += K.new_tensor([pl - cl, pt - ct])

    return image, K
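
# Hedged usage sketch (shapes are illustrative assumptions): given an image
# tensor of shape (1, 3, 720, 1280) and intrinsics of shape (1, 3, 3),
#   img, K = transform_img_and_K(img, (1024, 576), K=K, mode="crop")
# resizes so the (H, W) = (576, 1024) target is fully covered, center-crops to
# that size, and scales/shifts K to match the new pixel grid.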


lowvram_mode = False


def set_lowvram_mode(mode):
    global lowvram_mode
    lowvram_mode = mode


def load_model(model, device: str = "cuda"):
    model.to(device)


def unload_model(model):
    global lowvram_mode
    if lowvram_mode:
        model.cpu()
        torch.cuda.empty_cache()


def infer_prior_stats(
    T,
    num_input_frames,
    num_total_frames,
    version_dict,
):
    options = version_dict["options"]
    chunk_strategy = options.get("chunk_strategy", "nearest")
    T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
    T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
    # get traj_prior_c2ws for 2-pass sampling
    if chunk_strategy.startswith("interp"):
        # Start and end frames have already taken up two slots.
        # The +1 means we need X + 1 prior frames to bound X forward passes
        # covering all test frames.

        # Tuning up `num_prior_frames_ratio` is helpful when you observe sudden
        # jumps in the generated frames due to insufficient prior frames. This
        # option is effective for complicated trajectories and when the `interp`
        # strategy is used (usually the semi-dense-view regime). The recommended
        # range is [1.0 (default), 1.5].
        if num_input_frames >= options.get("num_input_semi_dense", 9):
            num_prior_frames = (
                math.ceil(
                    num_total_frames
                    / (T_second_pass - 2)
                    * options.get("num_prior_frames_ratio", 1.0)
                )
                + 1
            )

            if num_prior_frames + num_input_frames < T_first_pass:
                num_prior_frames = T_first_pass - num_input_frames

            num_prior_frames = max(
                num_prior_frames,
                options.get("num_prior_frames", 0),
            )

            T_first_pass = num_prior_frames + num_input_frames

            if "gt" in chunk_strategy:
                T_second_pass = T_second_pass + num_input_frames

            # Dynamically update context window length.
            version_dict["T"] = [T_first_pass, T_second_pass]

        else:
            num_prior_frames = (
                math.ceil(
                    num_total_frames
                    / (
                        T_second_pass
                        - 2
                        - (num_input_frames if "gt" in chunk_strategy else 0)
                    )
                    * options.get("num_prior_frames_ratio", 1.0)
                )
                + 1
            )

            if num_prior_frames + num_input_frames < T_first_pass:
                num_prior_frames = T_first_pass - num_input_frames

            num_prior_frames = max(
                num_prior_frames,
                options.get("num_prior_frames", 0),
            )
    else:
        num_prior_frames = max(
            T_first_pass - num_input_frames,
            options.get("num_prior_frames", 0),
        )

        if num_input_frames >= options.get("num_input_semi_dense", 9):
            T_first_pass = num_prior_frames + num_input_frames

            # Dynamically update context window length.
            version_dict["T"] = [T_first_pass, T_second_pass]

    return num_prior_frames
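
# Worked example (toy numbers, sparse-view regime): with T = [21, 21],
# num_input_frames = 3, num_total_frames = 80 and chunk_strategy = "interp",
# the interp branch first estimates ceil(80 / (21 - 2)) + 1 = 6 prior frames;
# since 6 + 3 < T_first_pass = 21 this is bumped to 21 - 3 = 18, so 18 prior
# frames are requested for the first pass.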


def infer_prior_inds(
    c2ws,
    num_prior_frames,
    input_frame_indices,
    options,
):
    chunk_strategy = options.get("chunk_strategy", "nearest")
    if chunk_strategy.startswith("interp"):
        prior_frame_indices = np.array(
            [i for i in range(c2ws.shape[0]) if i not in input_frame_indices]
        )
        prior_frame_indices = prior_frame_indices[
            np.ceil(
                np.linspace(
                    0, prior_frame_indices.shape[0] - 1, num_prior_frames, endpoint=True
                )
            ).astype(int)
        ]  # using ceil here is safer for corner cases
    else:
        prior_frame_indices = []
        while len(prior_frame_indices) < num_prior_frames:
            closest_distance = np.abs(
                np.arange(c2ws.shape[0])[None]
                - np.concatenate(
                    [np.array(input_frame_indices), np.array(prior_frame_indices)]
                )[:, None]
            ).min(0)
            prior_frame_indices.append(np.argsort(closest_distance)[-1])
    return np.sort(prior_frame_indices)
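
# Example (toy numbers, `interp` strategy): with 10 camera poses,
# input_frame_indices = [0, 9] and num_prior_frames = 3, the candidate prior
# indices are [1..8]; np.linspace(0, 7, 3) = [0, 3.5, 7] rounds up (ceil) to
# positions [0, 4, 7], so the selected prior frame indices are [1, 5, 8].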


def compute_relative_inds(
    source_inds,
    target_inds,
):
    assert len(source_inds) > 2
    # compute relative indices of target_inds within source_inds
    relative_inds = []
    for ind in target_inds:
        if ind in source_inds:
            relative_ind = int(np.where(source_inds == ind)[0][0])
        elif ind < source_inds[0]:
            # extrapolate
            relative_ind = -((source_inds[0] - ind) / (source_inds[1] - source_inds[0]))
        elif ind > source_inds[-1]:
            # extrapolate
            relative_ind = len(source_inds) + (
                (ind - source_inds[-1]) / (source_inds[-1] - source_inds[-2])
            )
        else:
            # interpolate
            lower_inds = source_inds[source_inds < ind]
            upper_inds = source_inds[source_inds > ind]
            if len(lower_inds) > 0 and len(upper_inds) > 0:
                lower_ind = lower_inds[-1]
                upper_ind = upper_inds[0]
                relative_lower_ind = int(np.where(source_inds == lower_ind)[0][0])
                relative_upper_ind = int(np.where(source_inds == upper_ind)[0][0])
                relative_ind = relative_lower_ind + (ind - lower_ind) / (
                    upper_ind - lower_ind
                ) * (relative_upper_ind - relative_lower_ind)
            else:
                # Out of range of both neighbours (should not be reachable
                # given the branches above); append a NaN placeholder and skip
                # the generic append below.
                relative_inds.append(float("nan"))
                continue
        relative_inds.append(relative_ind)
    return relative_inds
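
# Example: with source_inds = np.array([0, 4, 8]) and target_inds = [6], the
# interpolation branch finds the neighbours 4 and 8 (positions 1 and 2 in
# source_inds) and returns [1.5], i.e. index 6 sits halfway between them.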


def find_nearest_source_inds(
    source_c2ws,
    target_c2ws,
    nearest_num=1,
    mode="translation",
):
    dists = get_camera_dist(source_c2ws, target_c2ws, mode=mode).cpu().numpy()
    sorted_inds = np.argsort(dists, axis=0).T
    return sorted_inds[:, :nearest_num]


def chunk_input_and_test(
    T,
    input_c2ws,
    test_c2ws,
    input_ords,  # orders
    test_ords,  # orders
    options,
    task: str = "img2img",
    chunk_strategy: str = "gt",
    gt_input_inds: list = [],
):
    M, N = input_c2ws.shape[0], test_c2ws.shape[0]

    chunks = []
    if chunk_strategy.startswith("gt"):
        assert len(gt_input_inds) < T, (
            f"Number of gt input frames {len(gt_input_inds)} should be "
            f"less than {T} when `gt` chunking strategy is used."
        )
        assert (
            list(range(M)) == gt_input_inds
        ), "All input_c2ws should be gt when `gt` chunking strategy is used."

        # LEGACY CHUNKING STRATEGY
        # num_test_per_chunk = T - len(gt_input_inds)
        # test_inds_per_chunk = [i for i in range(T) if i not in gt_input_inds]
        # for i in range(0, test_c2ws.shape[0], num_test_per_chunk):
        #     chunk = ["NULL"] * T
        #     for j, k in enumerate(gt_input_inds):
        #         chunk[k] = f"!{j:03d}"
        #     for j, k in enumerate(
        #         test_inds_per_chunk[: test_c2ws[i : i + num_test_per_chunk].shape[0]]
        #     ):
        #         chunk[k] = f">{i + j:03d}"
        #     chunks.append(chunk)

        num_test_seen = 0
        while num_test_seen < N:
            chunk = [f"!{i:03d}" for i in gt_input_inds]
            if chunk_strategy != "gt" and num_test_seen > 0:
                pseudo_num_ratio = options.get("pseudo_num_ratio", 0.33)
                if (N - num_test_seen) >= math.floor(
                    (T - len(gt_input_inds)) * pseudo_num_ratio
                ):
                    pseudo_num = math.ceil((T - len(gt_input_inds)) * pseudo_num_ratio)
                else:
                    pseudo_num = (T - len(gt_input_inds)) - (N - num_test_seen)
                pseudo_num = min(pseudo_num, options.get("pseudo_num_max", 10000))

                if "ltr" in chunk_strategy:
                    chunk.extend(
                        [
                            f"!{i + len(gt_input_inds):03d}"
                            for i in range(num_test_seen - pseudo_num, num_test_seen)
                        ]
                    )
                elif "nearest" in chunk_strategy:
                    source_inds = np.concatenate(
                        [
                            find_nearest_source_inds(
                                test_c2ws[:num_test_seen],
                                test_c2ws[num_test_seen:],
                                nearest_num=1,  # pseudo_num,
                                mode="rotation",
                            ),
                            find_nearest_source_inds(
                                test_c2ws[:num_test_seen],
                                test_c2ws[num_test_seen:],
                                nearest_num=1,  # pseudo_num,
                                mode="translation",
                            ),
                        ],
                        axis=1,
                    )
                    ####### [HACK ALERT] keep running until pseudo num is stabilized ########
                    temp_pseudo_num = pseudo_num
                    while True:
                        nearest_source_inds = np.concatenate(
                            [
                                np.sort(
                                    [
                                        ind
                                        for (ind, _) in collections.Counter(
                                            [
                                                item
                                                for item in source_inds[
                                                    : T
                                                    - len(gt_input_inds)
                                                    - temp_pseudo_num
                                                ]
                                                .flatten()
                                                .tolist()
                                                if item
                                                != (
                                                    num_test_seen - 1
                                                )  # exclude the last one here
                                            ]
                                        ).most_common(pseudo_num - 1)
                                    ],
                                ).astype(int),
                                [num_test_seen - 1],  # always keep the last one
                            ]
                        )
                        if len(nearest_source_inds) >= temp_pseudo_num:
                            break  # stabilized
                        else:
                            temp_pseudo_num = len(nearest_source_inds)
                    pseudo_num = len(nearest_source_inds)
                    ########################################################################
                    chunk.extend(
                        [f"!{i + len(gt_input_inds):03d}" for i in nearest_source_inds]
                    )
                else:
                    raise NotImplementedError(
                        f"Chunking strategy {chunk_strategy} for the first pass is not implemented."
                    )

                chunk.extend(
                    [
                        f">{i:03d}"
                        for i in range(
                            num_test_seen,
                            min(num_test_seen + T - len(gt_input_inds) - pseudo_num, N),
                        )
                    ]
                )
            else:
                chunk.extend(
                    [
                        f">{i:03d}"
                        for i in range(
                            num_test_seen,
                            min(num_test_seen + T - len(gt_input_inds), N),
                        )
                    ]
                )

            num_test_seen += sum([1 for c in chunk if c.startswith(">")])
            if len(chunk) < T:
                chunk.extend(["NULL"] * (T - len(chunk)))
            chunks.append(chunk)

    elif chunk_strategy.startswith("nearest"):
        input_imgs = np.array([f"!{i:03d}" for i in range(M)])
        test_imgs = np.array([f">{i:03d}" for i in range(N)])

        match = re.match(r"^nearest-(\d+)$", chunk_strategy)
        if match:
            nearest_num = int(match.group(1))
            assert (
                nearest_num < T
            ), f"Nearest number of {nearest_num} should be less than {T}."
            source_inds = find_nearest_source_inds(
                input_c2ws,
                test_c2ws,
                nearest_num=nearest_num,
                mode="translation",  # during the second pass, consider translation only is enough
            )

            for i in range(0, N, T - nearest_num):
                nearest_source_inds = np.sort(
                    [
                        ind
                        for (ind, _) in collections.Counter(
                            source_inds[i : i + T - nearest_num].flatten().tolist()
                        ).most_common(nearest_num)
                    ]
                )
                chunk = (
                    input_imgs[nearest_source_inds].tolist()
                    + test_imgs[i : i + T - nearest_num].tolist()
                )
                chunks.append(chunk + ["NULL"] * (T - len(chunk)))

        else:
            # do not always condition on gt cond frames
            if "gt" not in chunk_strategy:
                gt_input_inds = []

            source_inds = find_nearest_source_inds(
                input_c2ws,
                test_c2ws,
                nearest_num=1,
                mode="translation",  # during the second pass, consider translation only is enough
            )[:, 0]

            test_inds_per_input = {}
            for test_idx, input_idx in enumerate(source_inds):
                if input_idx not in test_inds_per_input:
                    test_inds_per_input[input_idx] = []
                test_inds_per_input[input_idx].append(test_idx)

            num_test_seen = 0
            chunk = input_imgs[gt_input_inds].tolist()
            candidate_input_inds = sorted(list(test_inds_per_input.keys()))

            while num_test_seen < N:
                input_idx = candidate_input_inds[0]
                test_inds = test_inds_per_input[input_idx]
                input_is_cond = input_idx in gt_input_inds
                prefix_inds = [] if input_is_cond else [input_idx]

                if len(chunk) == T - len(prefix_inds) or not candidate_input_inds:
                    if chunk:
                        chunk += ["NULL"] * (T - len(chunk))
                        chunks.append(chunk)
                        chunk = input_imgs[gt_input_inds].tolist()
                    if num_test_seen >= N:
                        break
                    continue

                candidate_chunk = (
                    input_imgs[prefix_inds].tolist() + test_imgs[test_inds].tolist()
                )

                space_left = T - len(chunk)
                if len(candidate_chunk) <= space_left:
                    chunk.extend(candidate_chunk)
                    num_test_seen += len(test_inds)
                    candidate_input_inds.pop(0)
                else:
                    chunk.extend(candidate_chunk[:space_left])
                    num_input_idx = 0 if input_is_cond else 1
                    num_test_seen += space_left - num_input_idx
                    test_inds_per_input[input_idx] = test_inds[
                        space_left - num_input_idx :
                    ]

                if len(chunk) == T:
                    chunks.append(chunk)
                    chunk = input_imgs[gt_input_inds].tolist()

            if chunk and chunk != input_imgs[gt_input_inds].tolist():
                chunks.append(chunk + ["NULL"] * (T - len(chunk)))

    elif chunk_strategy.startswith("interp"):
        # `interp` chunk requires ordering info
        assert input_ords is not None and test_ords is not None, (
            "When using `interp` chunking strategy, ordering of input "
            "and test frames should be provided."
        )

        # if chunk_strategy is `interp*` and task is `img2trajvid*`, we will not
        # use input views since their order info within target views is unknown
        if "img2trajvid" in task:
            assert (
                list(range(len(gt_input_inds))) == gt_input_inds
            ), "`img2trajvid` task should put `gt_input_inds` in start."
            input_c2ws = input_c2ws[
                [ind for ind in range(M) if ind not in gt_input_inds]
            ]
            input_ords = [
                input_ords[ind] for ind in range(M) if ind not in gt_input_inds
            ]
            M = input_c2ws.shape[0]

        input_ords = [0] + input_ords  # a hack accounting for test views
        # before the first input view
        input_ords[-1] += 0.01  # a hack ensuring the last test stop is included
        # in the last forward pass when input_ords[-1] == test_ords[-1]
        input_ords = np.array(input_ords)[:, None]
        input_ords_ = np.concatenate([input_ords[1:], np.full((1, 1), np.inf)])
        test_ords = np.array(test_ords)[None]

        in_stop_ranges = np.logical_and(
            np.repeat(input_ords, N, axis=1) <= np.repeat(test_ords, M + 1, axis=0),
            np.repeat(input_ords_, N, axis=1) > np.repeat(test_ords, M + 1, axis=0),
        )  # (M + 1, N)
        assert (in_stop_ranges.sum(1) <= T - 2).all(), (
            "More input frames need to be sampled during the first pass to ensure "
            f"#test frames during each forard in the second pass will not exceed {T - 2}."
        )
        if input_ords[1, 0] <= test_ords[0, 0]:
            assert not in_stop_ranges[0].any()
        if input_ords[-1, 0] >= test_ords[0, -1]:
            assert not in_stop_ranges[-1].any()

        gt_chunk = (
            [f"!{i:03d}" for i in gt_input_inds] if "gt" in chunk_strategy else []
        )
        chunk = gt_chunk + []
        # any test views before the first input views
        if in_stop_ranges[0].any():
            for j, in_range in enumerate(in_stop_ranges[0]):
                if in_range:
                    chunk.append(f">{j:03d}")
        in_stop_ranges = in_stop_ranges[1:]

        i = 0
        base_i = len(gt_input_inds) if "img2trajvid" in task else 0
        chunk.append(f"!{i + base_i:03d}")
        while i < len(in_stop_ranges):
            in_stop_range = in_stop_ranges[i]
            if not in_stop_range.any():
                i += 1
                continue

            input_left = i + 1 < M
            space_left = T - len(chunk)
            if sum(in_stop_range) + input_left <= space_left:
                for j, in_range in enumerate(in_stop_range):
                    if in_range:
                        chunk.append(f">{j:03d}")
                i += 1
                if input_left:
                    chunk.append(f"!{i + base_i:03d}")

            else:
                chunk += ["NULL"] * space_left
                chunks.append(chunk)
                chunk = gt_chunk + [f"!{i + base_i:03d}"]

        if len(chunk) > 1:
            chunk += ["NULL"] * (T - len(chunk))
            chunks.append(chunk)

    else:
        raise NotImplementedError

    (
        input_inds_per_chunk,
        input_sels_per_chunk,
        test_inds_per_chunk,
        test_sels_per_chunk,
    ) = (
        [],
        [],
        [],
        [],
    )
    for chunk in chunks:
        input_inds = [
            int(img.removeprefix("!")) for img in chunk if img.startswith("!")
        ]
        input_sels = [chunk.index(img) for img in chunk if img.startswith("!")]
        test_inds = [int(img.removeprefix(">")) for img in chunk if img.startswith(">")]
        test_sels = [chunk.index(img) for img in chunk if img.startswith(">")]
        input_inds_per_chunk.append(input_inds)
        input_sels_per_chunk.append(input_sels)
        test_inds_per_chunk.append(test_inds)
        test_sels_per_chunk.append(test_sels)

    if options.get("sampler_verbose", True):

        def colorize(item):
            if item.startswith("!"):
                return f"{Fore.RED}{item}{Style.RESET_ALL}"  # Red for items starting with '!'
            elif item.startswith(">"):
                return f"{Fore.GREEN}{item}{Style.RESET_ALL}"  # Green for items starting with '>'
            return item  # Default color if neither '!' nor '>'

        print("\nchunks:")
        for chunk in chunks:
            print(", ".join(colorize(item) for item in chunk))

    return (
        chunks,
        input_inds_per_chunk,  # ordering of input in raw sequence
        input_sels_per_chunk,  # ordering of input in one-forward sequence of length T
        test_inds_per_chunk,  # ordering of test in raw sequence
        test_sels_per_chunk,  # ordering of test in one-forward sequence of length T
    )
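
# Note on the chunk encoding used above: each chunk is a length-T list where
# "!012" denotes input frame 12, ">034" denotes test frame 34, and "NULL" pads
# unused slots, e.g. ["!000", "!001", ">000", ">001", "NULL"] for T = 5 with
# two gt input frames and two test frames.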


def is_k_in_dict(d, k):
    return any(map(lambda x: x.startswith(k), d.keys()))


def get_k_from_dict(d, k):
    media_d = {}
    for key, value in d.items():
        if key == k:
            return value
        if key.startswith(k):
            media = key.split("/")[-1]
            if media == "raw":
                return value
            media_d[media] = value
    if len(media_d) == 0:
        return torch.tensor([])
    assert (
        len(media_d) == 1
    ), f"multiple media found in {d} for key {k}: {media_d.keys()}"
    return media_d[media]


def update_kv_for_dict(d, k, v):
    for key in d.keys():
        if key.startswith(k):
            d[key] = v
    return d


def extend_dict(ds, d):
    for key in d.keys():
        if key in ds:
            ds[key] = torch.cat([ds[key], d[key]], 0)
        else:
            ds[key] = d[key]
    return ds


def replace_or_include_input_for_dict(
    samples,
    test_indices,
    imgs,
    c2w,
    K,
):
    samples_new = {}
    for sample, value in samples.items():
        if "rgb" in sample:
            imgs[test_indices] = (
                value[test_indices] if value.shape[0] == imgs.shape[0] else value
            ).to(device=imgs.device, dtype=imgs.dtype)
            samples_new[sample] = imgs
        elif "c2w" in sample:
            c2w[test_indices] = (
                value[test_indices] if value.shape[0] == c2w.shape[0] else value
            ).to(device=c2w.device, dtype=c2w.dtype)
            samples_new[sample] = c2w
        elif "intrinsics" in sample:
            K[test_indices] = (
                value[test_indices] if value.shape[0] == K.shape[0] else value
            ).to(device=K.device, dtype=K.dtype)
            samples_new[sample] = K
        else:
            samples_new[sample] = value
    return samples_new


def decode_output(
    samples,
    T,
    indices=None,
):
    # decode model output into dict if it is not
    if isinstance(samples, dict):
        # model with postprocessor and outputs dict
        for sample, value in samples.items():
            if isinstance(value, torch.Tensor):
                value = value.detach().cpu()
            elif isinstance(value, np.ndarray):
                value = torch.from_numpy(value)
            else:
                value = torch.tensor(value)

            if indices is not None and value.shape[0] == T:
                value = value[indices]
            samples[sample] = value
    else:
        # model without postprocessor and outputs tensor (rgb)
        samples = samples.detach().cpu()

        if indices is not None and samples.shape[0] == T:
            samples = samples[indices]
        samples = {"samples-rgb/image": samples}

    return samples


def save_output(
    samples,
    save_path,
    video_save_fps=2,
):
    os.makedirs(save_path, exist_ok=True)
    for sample in samples:
        media_type = "video"
        if "/" in sample:
            sample_, media_type = sample.split("/")
        else:
            sample_ = sample

        value = samples[sample]
        if isinstance(value, torch.Tensor):
            value = value.detach().cpu()
        elif isinstance(value, np.ndarray):
            value = torch.from_numpy(value)
        else:
            value = torch.tensor(value)

        if media_type == "image":
            value = (value.permute(0, 2, 3, 1) + 1) / 2.0
            value = (value * 255).clamp(0, 255).to(torch.uint8)
            iio.imwrite(
                os.path.join(save_path, f"{sample_}.mp4")
                if sample_
                else f"{save_path}.mp4",
                value,
                fps=video_save_fps,
                macro_block_size=1,
                ffmpeg_log_level="error",
            )
            os.makedirs(os.path.join(save_path, sample_), exist_ok=True)
            for i, s in enumerate(value):
                iio.imwrite(
                    os.path.join(save_path, sample_, f"{i:03d}.png"),
                    s,
                )
        elif media_type == "video":
            value = (value.permute(0, 2, 3, 1) + 1) / 2.0
            value = (value * 255).clamp(0, 255).to(torch.uint8)
            iio.imwrite(
                os.path.join(save_path, f"{sample_}.mp4"),
                value,
                fps=video_save_fps,
                macro_block_size=1,
                ffmpeg_log_level="error",
            )
        elif media_type == "raw":
            torch.save(
                value,
                os.path.join(save_path, f"{sample_}.pt"),
            )
        else:
            pass
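
# Hedged usage sketch (the tensor and path are illustrative assumptions):
#   save_output({"samples-rgb/image": samples}, save_path="outputs/scene")
# writes `outputs/scene/samples-rgb.mp4` plus per-frame PNGs under
# `outputs/scene/samples-rgb/`, since the "image" media type saves both; keys
# follow the `<name>/<media>` convention used throughout this file.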


def create_transforms_simple(save_path, img_paths, img_whs, c2ws, Ks):
    import os.path as osp

    out_frames = []
    for img_path, img_wh, c2w, K in zip(img_paths, img_whs, c2ws, Ks):
        out_frame = {
            "fl_x": K[0][0].item(),
            "fl_y": K[1][1].item(),
            "cx": K[0][2].item(),
            "cy": K[1][2].item(),
            "w": img_wh[0].item(),
            "h": img_wh[1].item(),
            "file_path": f"./{osp.relpath(img_path, start=save_path)}"
            if img_path is not None
            else None,
            "transform_matrix": c2w.tolist(),
        }
        out_frames.append(out_frame)
    out = {
        # "camera_model": "PINHOLE",
        "orientation_override": "none",
        "frames": out_frames,
    }
    with open(osp.join(save_path, "transforms.json"), "w") as of:
        json.dump(out, of, indent=5)


class GradioTrackedSampler(EulerEDMSampler):
    """
    A thin wrapper around EulerEDMSampler that allows tracking progress and
    aborting sampling in the Gradio demo.
    """

    def __init__(self, abort_event: threading.Event, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.abort_event = abort_event

    def __call__(  # type: ignore
        self,
        denoiser,
        x: torch.Tensor,
        scale: float | torch.Tensor,
        cond: dict,
        uc: dict | None = None,
        num_steps: int | None = None,
        verbose: bool = True,
        global_pbar: gr.Progress | None = None,
        **guider_kwargs,
    ) -> torch.Tensor | None:
        uc = cond if uc is None else uc
        x, s_in, sigmas, num_sigmas, cond, uc = self.prepare_sampling_loop(
            x,
            cond,
            uc,
            num_steps,
        )
        for i in self.get_sigma_gen(num_sigmas, verbose=verbose):
            gamma = (
                min(self.s_churn / (num_sigmas - 1), 2**0.5 - 1)
                if self.s_tmin <= sigmas[i] <= self.s_tmax
                else 0.0
            )
            x = self.sampler_step(
                s_in * sigmas[i],
                s_in * sigmas[i + 1],
                denoiser,
                x,
                scale,
                cond,
                uc,
                gamma,
                **guider_kwargs,
            )
            # Allow tracking progress in gradio demo.
            if global_pbar is not None:
                global_pbar.update()
            # Allow aborting sampling in gradio demo.
            if self.abort_event.is_set():
                return None
        return x
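
# Note: `create_samplers` below instantiates this class whenever an
# `abort_event` is provided; when the event is set, `__call__` returns None and
# `do_sample` propagates that early exit.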


def create_samplers(
    guider_types: int | list[int],
    discretization,
    num_frames: list[int] | None,
    num_steps: int,
    cfg_min: float = 1.0,
    device: str | torch.device = "cuda",
    abort_event: threading.Event | None = None,
):
    guider_mapping = {
        0: VanillaCFG,
        1: MultiviewCFG,
        2: MultiviewTemporalCFG,
    }
    samplers = []
    if not isinstance(guider_types, (list, tuple)):
        guider_types = [guider_types]
    for i, guider_type in enumerate(guider_types):
        if guider_type not in guider_mapping:
            raise ValueError(
                f"Invalid guider type {guider_type}. Must be one of {list(guider_mapping.keys())}"
            )
        guider_cls = guider_mapping[guider_type]
        guider_args = ()
        if guider_type > 0:
            guider_args += (cfg_min,)
            if guider_type == 2:
                assert num_frames is not None
                guider_args = (num_frames[i], cfg_min)
        guider = guider_cls(*guider_args)

        if abort_event is not None:
            sampler = GradioTrackedSampler(
                abort_event,
                discretization=discretization,
                guider=guider,
                num_steps=num_steps,
                s_churn=0.0,
                s_tmin=0.0,
                s_tmax=999.0,
                s_noise=1.0,
                verbose=True,
                device=device,
            )
        else:
            sampler = EulerEDMSampler(
                discretization=discretization,
                guider=guider,
                num_steps=num_steps,
                s_churn=0.0,
                s_tmin=0.0,
                s_tmax=999.0,
                s_noise=1.0,
                verbose=True,
                device=device,
            )
        samplers.append(sampler)
    return samplers
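
# Hedged usage sketch (argument values are illustrative assumptions):
#   samplers = create_samplers(
#       guider_types=[1, 2],
#       discretization=denoiser.discretization,
#       num_frames=[21, 21],
#       num_steps=50,
#       cfg_min=1.2,
#   )
# builds one EulerEDMSampler per pass: a MultiviewCFG guider for the first pass
# and a MultiviewTemporalCFG guider (which also takes num_frames) for the
# second.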


def get_value_dict(
    curr_imgs,
    curr_imgs_clip,
    curr_input_frame_indices,
    curr_c2ws,
    curr_Ks,
    curr_input_camera_indices,
    all_c2ws,
    camera_scale,
):
    assert sorted(curr_input_camera_indices) == sorted(
        range(len(curr_input_camera_indices))
    )
    H, W, T, F = curr_imgs.shape[-2], curr_imgs.shape[-1], len(curr_imgs), 8

    value_dict = {}
    value_dict["cond_frames_without_noise"] = curr_imgs_clip[curr_input_frame_indices]
    value_dict["cond_frames"] = curr_imgs + 0.0 * torch.randn_like(curr_imgs)
    value_dict["cond_frames_mask"] = torch.zeros(T, dtype=torch.bool)
    value_dict["cond_frames_mask"][curr_input_frame_indices] = True
    value_dict["cond_aug"] = 0.0

    c2w = to_hom_pose(curr_c2ws.float())
    w2c = torch.linalg.inv(c2w)

    # camera centering
    ref_c2ws = all_c2ws
    camera_dist_2med = torch.norm(
        ref_c2ws[:, :3, 3] - ref_c2ws[:, :3, 3].median(0, keepdim=True).values,
        dim=-1,
    )
    valid_mask = camera_dist_2med <= torch.clamp(
        torch.quantile(camera_dist_2med, 0.97) * 10,
        max=1e6,
    )
    c2w[:, :3, 3] -= ref_c2ws[valid_mask, :3, 3].mean(0, keepdim=True)
    w2c = torch.linalg.inv(c2w)

    # camera normalization
    camera_dists = c2w[:, :3, 3].clone()
    translation_scaling_factor = (
        camera_scale
        if torch.isclose(
            torch.norm(camera_dists[0]),
            torch.zeros(1),
            atol=1e-5,
        ).any()
        else (camera_scale / torch.norm(camera_dists[0]))
    )
    w2c[:, :3, 3] *= translation_scaling_factor
    c2w[:, :3, 3] *= translation_scaling_factor
    value_dict["plucker_coordinate"], _ = get_plucker_coordinates(
        extrinsics_src=w2c[0],
        extrinsics=w2c,
        intrinsics=curr_Ks.float().clone(),
        mode="plucker",
        rel_zero_translation=True,
        target_size=(H // F, W // F),
        return_grid_cam=True,
    )

    value_dict["c2w"] = c2w
    value_dict["K"] = curr_Ks
    value_dict["camera_mask"] = torch.zeros(T, dtype=torch.bool)
    value_dict["camera_mask"][curr_input_camera_indices] = True

    return value_dict


def do_sample(
    model,
    ae,
    conditioner,
    denoiser,
    sampler,
    value_dict,
    H,
    W,
    C,
    F,
    T,
    cfg,
    encoding_t=1,
    decoding_t=1,
    verbose=True,
    global_pbar=None,
    **_,
):
    imgs = value_dict["cond_frames"].to("cuda")
    input_masks = value_dict["cond_frames_mask"].to("cuda")
    pluckers = value_dict["plucker_coordinate"].to("cuda")

    num_samples = [1, T]
    with torch.inference_mode(), torch.autocast("cuda"):
        load_model(ae)
        load_model(conditioner)
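        # `ae.encode` returns one latent per conditioning frame; the pad spec
        # (0, 0, 0, 0, 0, 1) appends one extra channel filled with 1.0 after
        # the latent channels (padding is specified from the last dim inward),
        # presumably marking these latents as clean conditioning inputs.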
        latents = torch.nn.functional.pad(
            ae.encode(imgs[input_masks], encoding_t), (0, 0, 0, 0, 0, 1), value=1.0
        )
        c_crossattn = repeat(conditioner(imgs[input_masks]).mean(0), "d -> n 1 d", n=T)
        uc_crossattn = torch.zeros_like(c_crossattn)
        c_replace = latents.new_zeros(T, *latents.shape[1:])
        c_replace[input_masks] = latents
        uc_replace = torch.zeros_like(c_replace)
        c_concat = torch.cat(
            [
                repeat(
                    input_masks,
                    "n -> n 1 h w",
                    h=pluckers.shape[2],
                    w=pluckers.shape[3],
                ),
                pluckers,
            ],
            1,
        )
        uc_concat = torch.cat(
            [pluckers.new_zeros(T, 1, *pluckers.shape[-2:]), pluckers], 1
        )
        c_dense_vector = pluckers
        uc_dense_vector = c_dense_vector
        c = {
            "crossattn": c_crossattn,
            "replace": c_replace,
            "concat": c_concat,
            "dense_vector": c_dense_vector,
        }
        uc = {
            "crossattn": uc_crossattn,
            "replace": uc_replace,
            "concat": uc_concat,
            "dense_vector": uc_dense_vector,
        }
        unload_model(ae)
        unload_model(conditioner)

        additional_model_inputs = {"num_frames": T}
        additional_sampler_inputs = {
            "c2w": value_dict["c2w"].to("cuda"),
            "K": value_dict["K"].to("cuda"),
            "input_frame_mask": value_dict["cond_frames_mask"].to("cuda"),
        }
        if global_pbar is not None:
            additional_sampler_inputs["global_pbar"] = global_pbar

        shape = (math.prod(num_samples), C, H // F, W // F)
        randn = torch.randn(shape).to("cuda")

        load_model(model)
        samples_z = sampler(
            lambda input, sigma, c: denoiser(
                model,
                input,
                sigma,
                c,
                **additional_model_inputs,
            ),
            randn,
            scale=cfg,
            cond=c,
            uc=uc,
            verbose=verbose,
            **additional_sampler_inputs,
        )
        if samples_z is None:
            return
        unload_model(model)

        load_model(ae)
        samples = ae.decode(samples_z, decoding_t)
        unload_model(ae)

    return samples


def run_one_scene(
    task,
    version_dict,
    model,
    ae,
    conditioner,
    denoiser,
    image_cond,
    camera_cond,
    save_path,
    use_traj_prior,
    traj_prior_Ks,
    traj_prior_c2ws,
    seed=23,
    gradio=False,
    abort_event=None,
    first_pass_pbar=None,
    second_pass_pbar=None,
):
    H, W, T, C, F, options = (
        version_dict["H"],
        version_dict["W"],
        version_dict["T"],
        version_dict["C"],
        version_dict["f"],
        version_dict["options"],
    )

    if isinstance(image_cond, str):
        image_cond = {"img": [image_cond]}
    imgs_clip, imgs, img_size = [], [], None
    for i, (img, K) in enumerate(zip(image_cond["img"], camera_cond["K"])):
        if isinstance(img, str) or img is None:
            img, K = load_img_and_K(img or img_size, None, K=K, device="cpu")  # type: ignore
            img_size = img.shape[-2:]
            if options.get("L_short", -1) == -1:
                img, K = transform_img_and_K(
                    img,
                    (W, H),
                    K=K[None],
                    mode=(
                        options.get("transform_input", "crop")
                        if i in image_cond["input_indices"]
                        else options.get("transform_target", "crop")
                    ),
                    scale=(
                        1.0
                        if i in image_cond["input_indices"]
                        else options.get("transform_scale", 1.0)
                    ),
                )
            else:
                downsample = 3
                assert options["L_short"] % F * 2**downsample == 0, (
                    "Short side of the image should be divisible by "
                    f"F*2**{downsample}={F * 2**downsample}."
                )
                img, K = transform_img_and_K(
                    img,
                    options["L_short"],
                    K=K[None],
                    size_stride=F * 2**downsample,
                    mode=(
                        options.get("transform_input", "crop")
                        if i in image_cond["input_indices"]
                        else options.get("transform_target", "crop")
                    ),
                    scale=(
                        1.0
                        if i in image_cond["input_indices"]
                        else options.get("transform_scale", 1.0)
                    ),
                )
                version_dict["W"] = W = img.shape[-1]
                version_dict["H"] = H = img.shape[-2]
            K = K[0]
            K[0] /= W
            K[1] /= H
            camera_cond["K"][i] = K
            img_clip = img
        elif isinstance(img, np.ndarray):
            img_size = torch.Size(img.shape[:2])
            img = torch.as_tensor(img).permute(2, 0, 1)
            img = img.unsqueeze(0)
            img = img / 255.0 * 2.0 - 1.0
            if not gradio:
                img, K = transform_img_and_K(img, (W, H), K=K[None])
                assert K is not None
                K = K[0]
            K[0] /= W
            K[1] /= H
            camera_cond["K"][i] = K
            img_clip = img
        else:
            assert (
                False
            ), f"Variable `img` got {type(img)} type which is not supported!!!"
        imgs_clip.append(img_clip)
        imgs.append(img)
    imgs_clip = torch.cat(imgs_clip, dim=0)
    imgs = torch.cat(imgs, dim=0)

    if traj_prior_Ks is not None:
        assert img_size is not None
        for i, prior_k in enumerate(traj_prior_Ks):
            img, prior_k = load_img_and_K(img_size, None, K=prior_k, device="cpu")  # type: ignore
            img, prior_k = transform_img_and_K(
                img,
                (W, H),
                K=prior_k[None],
                mode=options.get(
                    "transform_target", "crop"
                ),  # mode for prior is always same as target
                scale=options.get(
                    "transform_scale", 1.0
                ),  # scale for prior is always same as target
            )
            prior_k = prior_k[0]
            prior_k[0] /= W
            prior_k[1] /= H
            traj_prior_Ks[i] = prior_k

    options["num_frames"] = T
    discretization = denoiser.discretization
    torch.cuda.empty_cache()

    seed_everything(seed)

    # Get Data
    input_indices = image_cond["input_indices"]
    input_imgs = imgs[input_indices]
    input_imgs_clip = imgs_clip[input_indices]
    input_c2ws = camera_cond["c2w"][input_indices]
    input_Ks = camera_cond["K"][input_indices]

    test_indices = [i for i in range(len(imgs)) if i not in input_indices]
    test_imgs = imgs[test_indices]
    test_imgs_clip = imgs_clip[test_indices]
    test_c2ws = camera_cond["c2w"][test_indices]
    test_Ks = camera_cond["K"][test_indices]

    if options.get("save_input", True):
        save_output(
            {"/image": input_imgs},
            save_path=os.path.join(save_path, "input"),
            video_save_fps=2,
        )

    if not use_traj_prior:
        chunk_strategy = options.get("chunk_strategy", "gt")

        (
            _,
            input_inds_per_chunk,
            input_sels_per_chunk,
            test_inds_per_chunk,
            test_sels_per_chunk,
        ) = chunk_input_and_test(
            T,
            input_c2ws,
            test_c2ws,
            input_indices,
            test_indices,
            options=options,
            task=task,
            chunk_strategy=chunk_strategy,
            gt_input_inds=list(range(input_c2ws.shape[0])),
        )
        print(
            f"One pass - chunking with `{chunk_strategy}` strategy: total "
            f"{len(input_inds_per_chunk)} forward(s) ..."
        )

        all_samples = {}
        all_test_inds = []
        for i, (
            chunk_input_inds,
            chunk_input_sels,
            chunk_test_inds,
            chunk_test_sels,
        ) in tqdm(
            enumerate(
                zip(
                    input_inds_per_chunk,
                    input_sels_per_chunk,
                    test_inds_per_chunk,
                    test_sels_per_chunk,
                )
            ),
            total=len(input_inds_per_chunk),
            leave=False,
        ):
            (
                curr_input_sels,
                curr_test_sels,
                curr_input_maps,
                curr_test_maps,
            ) = pad_indices(
                chunk_input_sels,
                chunk_test_sels,
                T=T,
                padding_mode=options.get("t_padding_mode", "last"),
            )
            curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
                assemble(
                    input=x[chunk_input_inds],
                    test=y[chunk_test_inds],
                    input_maps=curr_input_maps,
                    test_maps=curr_test_maps,
                )
                for x, y in zip(
                    [
                        torch.cat(
                            [
                                input_imgs,
                                get_k_from_dict(all_samples, "samples-rgb").to(
                                    input_imgs.device
                                ),
                            ],
                            dim=0,
                        ),
                        torch.cat(
                            [
                                input_imgs_clip,
                                get_k_from_dict(all_samples, "samples-rgb").to(
                                    input_imgs.device
                                ),
                            ],
                            dim=0,
                        ),
                        torch.cat([input_c2ws, test_c2ws[all_test_inds]], dim=0),
                        torch.cat([input_Ks, test_Ks[all_test_inds]], dim=0),
                    ],  # procedurally append generated prior views to the input views
                    [test_imgs, test_imgs_clip, test_c2ws, test_Ks],
                )
            ]
            value_dict = get_value_dict(
                curr_imgs.to("cuda"),
                curr_imgs_clip.to("cuda"),
                curr_input_sels
                + [
                    sel
                    for (ind, sel) in zip(
                        np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
                        curr_test_sels,
                    )
                    if test_indices[ind] in image_cond["input_indices"]
                ],
                curr_c2ws,
                curr_Ks,
                curr_input_sels
                + [
                    sel
                    for (ind, sel) in zip(
                        np.array(chunk_test_inds)[curr_test_maps[curr_test_maps != -1]],
                        curr_test_sels,
                    )
                    if test_indices[ind] in camera_cond["input_indices"]
                ],
                all_c2ws=camera_cond["c2w"],
                camera_scale=options.get("camera_scale", 2.0),
            )
            samplers = create_samplers(
                options["guider_types"],
                discretization,
                [len(curr_imgs)],
                options["num_steps"],
                options["cfg_min"],
                abort_event=abort_event,
            )
            assert len(samplers) == 1
            samples = do_sample(
                model,
                ae,
                conditioner,
                denoiser,
                samplers[0],
                value_dict,
                H,
                W,
                C,
                F,
                T=len(curr_imgs),
                cfg=(
                    options["cfg"][0]
                    if isinstance(options["cfg"], (list, tuple))
                    else options["cfg"]
                ),
                **{k: options[k] for k in options if k not in ["cfg", "T"]},
            )
            samples = decode_output(
                samples, len(curr_imgs), chunk_test_sels
            )  # decode into dict
            if options.get("save_first_pass", False):
                save_output(
                    replace_or_include_input_for_dict(
                        samples,
                        chunk_test_sels,
                        curr_imgs,
                        curr_c2ws,
                        curr_Ks,
                    ),
                    save_path=os.path.join(save_path, "first-pass", f"forward_{i}"),
                    video_save_fps=2,
                )
            extend_dict(all_samples, samples)
            all_test_inds.extend(chunk_test_inds)
    else:
        assert traj_prior_c2ws is not None, (
            "`traj_prior_c2ws` should be set when using 2-pass sampling. One "
            "potential reason is that the number of input frames is larger than "
            "T. Set `num_prior_frames` manually to override the inferred stats."
        )
        traj_prior_c2ws = torch.as_tensor(
            traj_prior_c2ws,
            device=input_c2ws.device,
            dtype=input_c2ws.dtype,
        )

        if traj_prior_Ks is None:
            traj_prior_Ks = test_Ks[:1].repeat_interleave(
                traj_prior_c2ws.shape[0], dim=0
            )

        traj_prior_imgs = imgs.new_zeros(traj_prior_c2ws.shape[0], *imgs.shape[1:])
        traj_prior_imgs_clip = imgs_clip.new_zeros(
            traj_prior_c2ws.shape[0], *imgs_clip.shape[1:]
        )
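        # The prior-trajectory images start as zero-filled placeholders; the
        # first pass below generates them, and the second pass then conditions
        # on those generated priors when rendering the actual target views.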

        # ---------------------------------- first pass ----------------------------------
        T_first_pass = T[0] if isinstance(T, (list, tuple)) else T
        T_second_pass = T[1] if isinstance(T, (list, tuple)) else T
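        # T may be a single window length or a (first-pass, second-pass) pair,
        # e.g. a plain int or a pair like (T_first, T_second); each pass chunks
        # its frames into windows of the corresponding length.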
        chunk_strategy_first_pass = options.get(
            "chunk_strategy_first_pass", "gt-nearest"
        )
        (
            _,
            input_inds_per_chunk,
            input_sels_per_chunk,
            prior_inds_per_chunk,
            prior_sels_per_chunk,
        ) = chunk_input_and_test(
            T_first_pass,
            input_c2ws,
            traj_prior_c2ws,
            input_indices,
            image_cond["prior_indices"],
            options=options,
            task=task,
            chunk_strategy=chunk_strategy_first_pass,
            gt_input_inds=list(range(input_c2ws.shape[0])),
        )
        print(
            f"Two passes (first) - chunking with `{chunk_strategy_first_pass}` strategy: total "
            f"{len(input_inds_per_chunk)} forward(s) ..."
        )

        all_samples = {}
        all_prior_inds = []
        for i, (
            chunk_input_inds,
            chunk_input_sels,
            chunk_prior_inds,
            chunk_prior_sels,
        ) in tqdm(
            enumerate(
                zip(
                    input_inds_per_chunk,
                    input_sels_per_chunk,
                    prior_inds_per_chunk,
                    prior_sels_per_chunk,
                )
            ),
            total=len(input_inds_per_chunk),
            leave=False,
        ):
            (
                curr_input_sels,
                curr_prior_sels,
                curr_input_maps,
                curr_prior_maps,
            ) = pad_indices(
                chunk_input_sels,
                chunk_prior_sels,
                T=T_first_pass,
                padding_mode=options.get("t_padding_mode", "last"),
            )
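            # With the default "last" padding, the remaining slots of the
            # T_first_pass-frame window are filled by repeating the last prior
            # element: e.g. with T=6, chunk_input_sels=[0] and
            # chunk_prior_sels=[1, 2, 3], the padded prior selections become
            # [1, 2, 3, 4, 5], mapping onto prior elements [0, 1, 2, 2, 2].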
            curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
                assemble(
                    input=x[chunk_input_inds],
                    test=y[chunk_prior_inds],
                    input_maps=curr_input_maps,
                    test_maps=curr_prior_maps,
                )
                for x, y in zip(
                    [
                        torch.cat(
                            [
                                input_imgs,
                                get_k_from_dict(all_samples, "samples-rgb").to(
                                    input_imgs.device
                                ),
                            ],
                            dim=0,
                        ),
                        torch.cat(
                            [
                                input_imgs_clip,
                                get_k_from_dict(all_samples, "samples-rgb").to(
                                    input_imgs.device
                                ),
                            ],
                            dim=0,
                        ),
                        torch.cat([input_c2ws, traj_prior_c2ws[all_prior_inds]], dim=0),
                        torch.cat([input_Ks, traj_prior_Ks[all_prior_inds]], dim=0),
                    ],  # procedurally append generated prior views to the input views
                    [
                        traj_prior_imgs,
                        traj_prior_imgs_clip,
                        traj_prior_c2ws,
                        traj_prior_Ks,
                    ],
                )
            ]
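            # assemble interleaves the conditioning views and the target prior
            # views into one T_first_pass-frame window according to the index
            # maps computed by pad_indices above.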
            value_dict = get_value_dict(
                curr_imgs.to("cuda"),
                curr_imgs_clip.to("cuda"),
                curr_input_sels,
                curr_c2ws,
                curr_Ks,
                list(range(T_first_pass)),
                all_c2ws=camera_cond["c2w"],
                camera_scale=options.get("camera_scale", 2.0),
            )
            samplers = create_samplers(
                options["guider_types"],
                discretization,
                [T_first_pass, T_second_pass],
                options["num_steps"],
                options["cfg_min"],
                abort_event=abort_event,
            )
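            # create_samplers receives both pass lengths; the returned list is
            # reused in the second pass below, where samplers[1] (when present)
            # is used.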
            samples = do_sample(
                model,
                ae,
                conditioner,
                denoiser,
                (
                    samplers[1]
                    if len(samplers) > 1
                    and options.get("ltr_first_pass", False)
                    and chunk_strategy_first_pass != "gt"
                    and i > 0
                    else samplers[0]
                ),
                value_dict,
                H,
                W,
                C,
                F,
                cfg=(
                    options["cfg"][0]
                    if isinstance(options["cfg"], (list, tuple))
                    else options["cfg"]
                ),
                T=T_first_pass,
                global_pbar=first_pass_pbar,
                **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
            )
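            # do_sample is expected to return None if sampling was interrupted
            # (e.g. via abort_event), so propagate the early exit.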
            if samples is None:
                return
            samples = decode_output(
                samples, T_first_pass, chunk_prior_sels
            )  # decode into dict
            extend_dict(all_samples, samples)
            all_prior_inds.extend(chunk_prior_inds)

        if options.get("save_first_pass", True):
            save_output(
                all_samples,
                save_path=os.path.join(save_path, "first-pass"),
                video_save_fps=5,
            )
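            # Yield the first-pass preview video so the caller (e.g. the Gradio
            # UI consuming this generator) can show it while the second pass runs.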
            video_path_0 = os.path.join(save_path, "first-pass", "samples-rgb.mp4")
            yield video_path_0

        # ---------------------------------- second pass ----------------------------------
        prior_indices = image_cond["prior_indices"]
        assert (
            prior_indices is not None
        ), "`prior_indices` in `image_cond` needs to be set when using 2-pass sampling."
        prior_argsort = np.argsort(input_indices + prior_indices).tolist()
        prior_indices = np.array(input_indices + prior_indices)[prior_argsort].tolist()
        gt_input_inds = [prior_argsort.index(i) for i in range(input_c2ws.shape[0])]

        traj_prior_imgs = torch.cat(
            [input_imgs, get_k_from_dict(all_samples, "samples-rgb")], dim=0
        )[prior_argsort]
        traj_prior_imgs_clip = torch.cat(
            [
                input_imgs_clip,
                get_k_from_dict(all_samples, "samples-rgb"),
            ],
            dim=0,
        )[prior_argsort]
        traj_prior_c2ws = torch.cat([input_c2ws, traj_prior_c2ws], dim=0)[prior_argsort]
        traj_prior_Ks = torch.cat([input_Ks, traj_prior_Ks], dim=0)[prior_argsort]

        update_kv_for_dict(all_samples, "samples-rgb", traj_prior_imgs)
        update_kv_for_dict(all_samples, "samples-c2ws", traj_prior_c2ws)
        update_kv_for_dict(all_samples, "samples-intrinsics", traj_prior_Ks)

        chunk_strategy = options.get("chunk_strategy", "nearest")
        (
            _,
            prior_inds_per_chunk,
            prior_sels_per_chunk,
            test_inds_per_chunk,
            test_sels_per_chunk,
        ) = chunk_input_and_test(
            T_second_pass,
            traj_prior_c2ws,
            test_c2ws,
            prior_indices,
            test_indices,
            options=options,
            task=task,
            chunk_strategy=chunk_strategy,
            gt_input_inds=gt_input_inds,
        )
        print(
            f"Two passes (second) - chunking with `{chunk_strategy}` strategy: total "
            f"{len(prior_inds_per_chunk)} forward(s) ..."
        )

        all_samples = {}
        all_test_inds = []
        for i, (
            chunk_prior_inds,
            chunk_prior_sels,
            chunk_test_inds,
            chunk_test_sels,
        ) in tqdm(
            enumerate(
                zip(
                    prior_inds_per_chunk,
                    prior_sels_per_chunk,
                    test_inds_per_chunk,
                    test_sels_per_chunk,
                )
            ),
            total=len(prior_inds_per_chunk),
            leave=False,
        ):
            (
                curr_prior_sels,
                curr_test_sels,
                curr_prior_maps,
                curr_test_maps,
            ) = pad_indices(
                chunk_prior_sels,
                chunk_test_sels,
                T=T_second_pass,
                padding_mode="last",
            )
            curr_imgs, curr_imgs_clip, curr_c2ws, curr_Ks = [
                assemble(
                    input=x[chunk_prior_inds],
                    test=y[chunk_test_inds],
                    input_maps=curr_prior_maps,
                    test_maps=curr_test_maps,
                )
                for x, y in zip(
                    [
                        traj_prior_imgs,
                        traj_prior_imgs_clip,
                        traj_prior_c2ws,
                        traj_prior_Ks,
                    ],
                    [test_imgs, test_imgs_clip, test_c2ws, test_Ks],
                )
            ]
            value_dict = get_value_dict(
                curr_imgs.to("cuda"),
                curr_imgs_clip.to("cuda"),
                curr_prior_sels,
                curr_c2ws,
                curr_Ks,
                list(range(T_second_pass)),
                all_c2ws=camera_cond["c2w"],
                camera_scale=options.get("camera_scale", 2.0),
            )
            samples = do_sample(
                model,
                ae,
                conditioner,
                denoiser,
                samplers[1] if len(samplers) > 1 else samplers[0],
                value_dict,
                H,
                W,
                C,
                F,
                T=T_second_pass,
                cfg=(
                    options["cfg"][1]
                    if isinstance(options["cfg"], (list, tuple))
                    and len(options["cfg"]) > 1
                    else options["cfg"]
                ),
                global_pbar=second_pass_pbar,
                **{k: options[k] for k in options if k not in ["cfg", "T", "sampler"]},
            )
            if samples is None:
                return
            samples = decode_output(
                samples, T_second_pass, chunk_test_sels
            )  # decode into dict
            if options.get("save_second_pass", False):
                save_output(
                    replace_or_include_input_for_dict(
                        samples,
                        chunk_test_sels,
                        curr_imgs,
                        curr_c2ws,
                        curr_Ks,
                    ),
                    save_path=os.path.join(save_path, "second-pass", f"forward_{i}"),
                    video_save_fps=2,
                )
            extend_dict(all_samples, samples)
            all_test_inds.extend(chunk_test_inds)
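        # Chunks may visit target frames out of order, so re-sort every decoded
        # tensor back into the original test-frame order before saving.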
        all_samples = {
            key: value[np.argsort(all_test_inds)] for key, value in all_samples.items()
        }
    save_output(
        replace_or_include_input_for_dict(
            all_samples,
            test_indices,
            imgs.clone(),
            camera_cond["c2w"].clone(),
            camera_cond["K"].clone(),
        )
        if options.get("replace_or_include_input", False)
        else all_samples,
        save_path=save_path,
        video_save_fps=options.get("video_save_fps", 2),
    )
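    # Yield the final video path (with input views spliced back in when
    # "replace_or_include_input" is set) as the generator's last output.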
    video_path_1 = os.path.join(save_path, "samples-rgb.mp4")
    yield video_path_1