import os
from pathlib import Path
import json
import time
import random
from typing import *
import traceback
import itertools
from numbers import Number
import io
import numpy as np
import cv2
from PIL import Image
import torch
import torchvision.transforms.v2.functional as TF
import utils3d
from tqdm import tqdm
from ..utils import pipeline
from ..utils.io import *
from ..utils.geometry_numpy import mask_aware_nearest_resize_numpy, harmonic_mean_numpy, norm3d, depth_occlusion_edge_numpy, depth_of_field


class TrainDataLoaderPipeline:
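    """
    Streaming training-data pipeline.

    Samples batches of (image, depth, mask, intrinsics) instances from the configured
    datasets, applies FoV / principal-point / flip / photometric augmentations, and
    yields collated torch tensors. Loading and processing run in parallel worker stages
    (see the `pipeline` utilities).
    """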
    def __init__(self, config: dict, batch_size: int, num_load_workers: int = 4, num_process_workers: int = 8, buffer_size: int = 8):
        self.config = config
        self.batch_size = batch_size
        self.clamp_max_depth = config['clamp_max_depth']
        self.fov_range_absolute = config.get('fov_range_absolute', 0.0)
        self.fov_range_relative = config.get('fov_range_relative', 0.0)
        self.center_augmentation = config.get('center_augmentation', 0.0)
        self.image_augmentation = config.get('image_augmentation', [])
        self.depth_interpolation = config.get('depth_interpolation', 'bilinear')
        if 'image_sizes' in config:
            self.image_size_strategy = 'fixed'
            self.image_sizes = config['image_sizes']
        elif 'aspect_ratio_range' in config and 'area_range' in config:
            self.image_size_strategy = 'aspect_area'
            self.aspect_ratio_range = config['aspect_ratio_range']
            self.area_range = config['area_range']
        else:
            raise ValueError('Invalid image size configuration')
        # Load datasets
        self.datasets = {}
        for dataset in tqdm(config['datasets'], desc='Loading datasets'):
            name = dataset['name']
            content = Path(dataset['path'], dataset.get('index', '.index.txt')).read_text()
            filenames = content.splitlines()
            self.datasets[name] = {
                **dataset,
                'path': dataset['path'],
                'filenames': filenames,
            }
        self.dataset_names = [dataset['name'] for dataset in config['datasets']]
        self.dataset_weights = [dataset['weight'] for dataset in config['datasets']]
        # Build pipeline
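        # Stages run as a chain: sample batch descriptors -> unbatch to single instances ->
        # load in parallel -> process/augment in parallel -> regroup into batches ->
        # collate to tensors -> prefetch into a buffer of `buffer_size` batches.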
        self.pipeline = pipeline.Sequential([
            self._sample_batch,
            pipeline.Unbatch(),
            pipeline.Parallel([self._load_instance] * num_load_workers),
            pipeline.Parallel([self._process_instance] * num_process_workers),
            pipeline.Batch(self.batch_size),
            self._collate_batch,
            pipeline.Buffer(buffer_size),
        ])
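        # Fallback instance substituted when loading fails; its label_type of 'invalid'
        # lets downstream code skip it without breaking the batch shape.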
        self.invalid_instance = {
            'intrinsics': np.array([[1.0, 0.0, 0.5], [0.0, 1.0, 0.5], [0.0, 0.0, 1.0]], dtype=np.float32),
            'image': np.zeros((256, 256, 3), dtype=np.uint8),
            'depth': np.ones((256, 256), dtype=np.float32),
            'depth_mask': np.ones((256, 256), dtype=bool),
            'depth_mask_inf': np.zeros((256, 256), dtype=bool),
            'label_type': 'invalid',
        }

    def _sample_batch(self):
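        """Endlessly yield lists of `batch_size` instance descriptors, each with a sampled
        dataset, filename, RNG seed, and the shared target image size for the batch."""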
        batch_id = 0
        last_area = None
        while True:
            # Depending on the sample strategy, choose a dataset and a filename
            batch_id += 1
            batch = []
            # Sample instances
            for _ in range(self.batch_size):
                dataset_name = random.choices(self.dataset_names, weights=self.dataset_weights)[0]
                filename = random.choice(self.datasets[dataset_name]['filenames'])
                path = Path(self.datasets[dataset_name]['path'], filename)
                instance = {
                    'batch_id': batch_id,
                    'seed': random.randint(0, 2 ** 32 - 1),
                    'dataset': dataset_name,
                    'filename': filename,
                    'path': path,
                    'label_type': self.datasets[dataset_name]['label_type'],
                }
                batch.append(instance)
            # Decide the image size for this batch
            if self.image_size_strategy == 'fixed':
                width, height = random.choice(self.config['image_sizes'])
            elif self.image_size_strategy == 'aspect_area':
                area = random.uniform(*self.area_range)
                aspect_ratio_ranges = [self.datasets[instance['dataset']].get('aspect_ratio_range', self.aspect_ratio_range) for instance in batch]
                aspect_ratio_range = (min(r[0] for r in aspect_ratio_ranges), max(r[1] for r in aspect_ratio_ranges))
                aspect_ratio = random.uniform(*aspect_ratio_range)
                width, height = int((area * aspect_ratio) ** 0.5), int((area / aspect_ratio) ** 0.5)
            else:
                raise ValueError('Invalid image size strategy')
            for instance in batch:
                instance['width'], instance['height'] = width, height
            yield batch

    def _load_instance(self, instance: dict):
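        """Read the image, depth map, and camera intrinsics for one instance from disk.
        On any failure the instance is replaced by `self.invalid_instance` so the
        pipeline keeps running."""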
        try:
            image = read_image(Path(instance['path'], 'image.jpg'))
            depth, _ = read_depth(Path(instance['path'], self.datasets[instance['dataset']].get('depth', 'depth.png')))
            meta = read_meta(Path(instance['path'], 'meta.json'))
            intrinsics = np.array(meta['intrinsics'], dtype=np.float32)
            depth_mask = np.isfinite(depth)
            depth_mask_inf = np.isinf(depth)
            depth = np.nan_to_num(depth, nan=1, posinf=1, neginf=1)
            data = {
                'image': image,
                'depth': depth,
                'depth_mask': depth_mask,
                'depth_mask_inf': depth_mask_inf,
                'intrinsics': intrinsics
            }
            instance.update({
                **data,
            })
        except Exception as e:
            print(f"Failed to load instance {instance['dataset']}/{instance['filename']} because of exception:", e)
            instance.update(self.invalid_instance)
        return instance

    def _process_instance(self, instance: Dict[str, Union[np.ndarray, str, float, bool]]):
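        """Geometrically and photometrically augment one loaded instance and convert it to
        torch tensors at the target size with the corresponding target intrinsics."""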
        image, depth, depth_mask, depth_mask_inf, intrinsics, label_type = instance['image'], instance['depth'], instance['depth_mask'], instance['depth_mask_inf'], instance['intrinsics'], instance['label_type']
        depth_unit = self.datasets[instance['dataset']].get('depth_unit', None)
        raw_height, raw_width = image.shape[:2]
        raw_horizontal, raw_vertical = abs(1.0 / intrinsics[0, 0]), abs(1.0 / intrinsics[1, 1])
        raw_fov_x, raw_fov_y = utils3d.numpy.intrinsics_to_fov(intrinsics)
        raw_pixel_w, raw_pixel_h = raw_horizontal / raw_width, raw_vertical / raw_height
        tgt_width, tgt_height = instance['width'], instance['height']
        tgt_aspect = tgt_width / tgt_height
        rng = np.random.default_rng(instance['seed'])
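        # The steps below sample a target FoV and principal point, rotate the camera toward
        # the new principal direction, and warp the image and depth with the resulting
        # homography, so every crop behaves like a view from a re-aimed pinhole camera.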
        # 1. set target fov
        center_augmentation = self.datasets[instance['dataset']].get('center_augmentation', self.center_augmentation)
        fov_range_absolute_min, fov_range_absolute_max = self.datasets[instance['dataset']].get('fov_range_absolute', self.fov_range_absolute)
        fov_range_relative_min, fov_range_relative_max = self.datasets[instance['dataset']].get('fov_range_relative', self.fov_range_relative)
        tgt_fov_x_min = min(fov_range_relative_min * raw_fov_x, fov_range_relative_min * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect))
        tgt_fov_x_max = min(fov_range_relative_max * raw_fov_x, fov_range_relative_max * utils3d.focal_to_fov(utils3d.fov_to_focal(raw_fov_y) / tgt_aspect))
        tgt_fov_x_min, tgt_fov_x_max = max(np.deg2rad(fov_range_absolute_min), tgt_fov_x_min), min(np.deg2rad(fov_range_absolute_max), tgt_fov_x_max)
        tgt_fov_x = rng.uniform(min(tgt_fov_x_min, tgt_fov_x_max), tgt_fov_x_max)
        tgt_fov_y = utils3d.focal_to_fov(utils3d.numpy.fov_to_focal(tgt_fov_x) * tgt_aspect)
        # 2. set target image center (principal point) and the corresponding z-direction in raw camera space
        center_dtheta = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_x - tgt_fov_x)
        center_dphi = center_augmentation * rng.uniform(-0.5, 0.5) * (raw_fov_y - tgt_fov_y)
        cu, cv = 0.5 + 0.5 * np.tan(center_dtheta) / np.tan(raw_fov_x / 2), 0.5 + 0.5 * np.tan(center_dphi) / np.tan(raw_fov_y / 2)
        direction = utils3d.unproject_cv(np.array([[cu, cv]], dtype=np.float32), np.array([1.0], dtype=np.float32), intrinsics=intrinsics)[0]
        # 3. obtain the rotation matrix for homography warping
        R = utils3d.rotation_matrix_from_vectors(direction, np.array([0, 0, 1], dtype=np.float32))
        # 4. shrink the target view to fit into the warped image
        corners = np.array([[0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32)
        corners = np.concatenate([corners, np.ones((4, 1), dtype=np.float32)], axis=1) @ (np.linalg.inv(intrinsics).T @ R.T)  # corners in viewport's camera plane
        corners = corners[:, :2] / corners[:, 2:3]
        tgt_horizontal, tgt_vertical = np.tan(tgt_fov_x / 2) * 2, np.tan(tgt_fov_y / 2) * 2
        warp_horizontal, warp_vertical = float('inf'), float('inf')
        for i in range(4):
            intersection, _ = utils3d.numpy.ray_intersection(
                np.array([0., 0.]), np.array([[tgt_aspect, 1.0], [tgt_aspect, -1.0]]),
                corners[i - 1], corners[i] - corners[i - 1],
            )
            warp_horizontal, warp_vertical = min(warp_horizontal, 2 * np.abs(intersection[:, 0]).min()), min(warp_vertical, 2 * np.abs(intersection[:, 1]).min())
        tgt_horizontal, tgt_vertical = min(tgt_horizontal, warp_horizontal), min(tgt_vertical, warp_vertical)
        # 5. obtain the target intrinsics
        fx, fy = 1 / tgt_horizontal, 1 / tgt_vertical
        tgt_intrinsics = utils3d.numpy.intrinsics_from_focal_center(fx, fy, 0.5, 0.5).astype(np.float32)
        # 6. apply the homography transformation
        # 6.1 The image and depth are resized first to approximately the same pixel size as the target image with PIL's antialiasing resampling
        tgt_pixel_w, tgt_pixel_h = tgt_horizontal / tgt_width, tgt_vertical / tgt_height  # (should be exactly the same for x and y axes)
        rescaled_w, rescaled_h = int(raw_width * raw_pixel_w / tgt_pixel_w), int(raw_height * raw_pixel_h / tgt_pixel_h)
        image = np.array(Image.fromarray(image).resize((rescaled_w, rescaled_h), Image.Resampling.LANCZOS))
        edge_mask = depth_occlusion_edge_numpy(depth, mask=depth_mask, thickness=2, tol=0.01)
        _, depth_mask_nearest, resize_index = mask_aware_nearest_resize_numpy(None, depth_mask, (rescaled_w, rescaled_h), return_index=True)
        depth_nearest = depth[resize_index]
        distance_nearest = norm3d(utils3d.numpy.depth_to_points(depth_nearest, intrinsics=intrinsics))
        edge_mask = edge_mask[resize_index]
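        # Bilinear depth is resampled in inverse-depth space and used only away from
        # occlusion edges; the mask-aware nearest-neighbour version is kept at edges and
        # partially valid pixels so foreground and background depths are never blended.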
        if self.depth_interpolation == 'bilinear':
            depth_mask_bilinear = cv2.resize(depth_mask.astype(np.float32), (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR)
            depth_bilinear = 1 / cv2.resize(1 / depth, (rescaled_w, rescaled_h), interpolation=cv2.INTER_LINEAR)
            distance_bilinear = norm3d(utils3d.numpy.depth_to_points(depth_bilinear, intrinsics=intrinsics))
        depth_mask_inf = cv2.resize(depth_mask_inf.astype(np.uint8), (rescaled_w, rescaled_h), interpolation=cv2.INTER_NEAREST) > 0
        # 6.2 calculate homography warping
        transform = intrinsics @ np.linalg.inv(R) @ np.linalg.inv(tgt_intrinsics)
        uv_tgt = utils3d.numpy.image_uv(width=tgt_width, height=tgt_height)
        pts = np.concatenate([uv_tgt, np.ones((tgt_height, tgt_width, 1), dtype=np.float32)], axis=-1) @ transform.T
        uv_remap = pts[:, :, :2] / (pts[:, :, 2:3] + 1e-12)
        pixel_remap = utils3d.numpy.uv_to_pixel(uv_remap, width=rescaled_w, height=rescaled_h).astype(np.float32)
        tgt_image = cv2.remap(image, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LANCZOS4)
        tgt_ray_length = norm3d(utils3d.numpy.unproject_cv(uv_tgt, np.ones_like(uv_tgt[:, :, 0]), intrinsics=tgt_intrinsics))
        tgt_depth_mask_nearest = cv2.remap(depth_mask_nearest.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
        tgt_depth_nearest = cv2.remap(distance_nearest, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) / tgt_ray_length
        tgt_edge_mask = cv2.remap(edge_mask.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
        if self.depth_interpolation == 'bilinear':
            tgt_depth_mask_bilinear = cv2.remap(depth_mask_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR)
            tgt_depth_bilinear = cv2.remap(distance_bilinear, pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_LINEAR) / tgt_ray_length
            tgt_depth = np.where((tgt_depth_mask_bilinear == 1) & ~tgt_edge_mask, tgt_depth_bilinear, tgt_depth_nearest)
        else:
            tgt_depth = tgt_depth_nearest
        tgt_depth_mask = tgt_depth_mask_nearest
        tgt_depth_mask_inf = cv2.remap(depth_mask_inf.astype(np.uint8), pixel_remap[:, :, 0], pixel_remap[:, :, 1], cv2.INTER_NEAREST) > 0
        # always make sure that mask is not empty
        if tgt_depth_mask.sum() / tgt_depth_mask.size < 0.001:
            tgt_depth_mask = np.ones_like(tgt_depth_mask)
            tgt_depth = np.ones_like(tgt_depth)
            instance['label_type'] = 'invalid'
        # Flip augmentation
        if rng.choice([True, False]):
            tgt_image = np.flip(tgt_image, axis=1).copy()
            tgt_depth = np.flip(tgt_depth, axis=1).copy()
            tgt_depth_mask = np.flip(tgt_depth_mask, axis=1).copy()
            tgt_depth_mask_inf = np.flip(tgt_depth_mask_inf, axis=1).copy()
        # Color augmentation
        image_augmentation = self.datasets[instance['dataset']].get('image_augmentation', self.image_augmentation)
        if 'jittering' in image_augmentation:
            tgt_image = torch.from_numpy(tgt_image).permute(2, 0, 1)
            tgt_image = TF.adjust_brightness(tgt_image, rng.uniform(0.7, 1.3))
            tgt_image = TF.adjust_contrast(tgt_image, rng.uniform(0.7, 1.3))
            tgt_image = TF.adjust_saturation(tgt_image, rng.uniform(0.7, 1.3))
            tgt_image = TF.adjust_hue(tgt_image, rng.uniform(-0.1, 0.1))
            tgt_image = TF.adjust_gamma(tgt_image, rng.uniform(0.7, 1.3))
            tgt_image = tgt_image.permute(1, 2, 0).numpy()
        if 'dof' in image_augmentation:
            if rng.uniform() < 0.5:
                dof_strength = rng.integers(12)
                tgt_disp = np.where(tgt_depth_mask_inf, 0, 1 / tgt_depth)
                disp_min, disp_max = tgt_disp[tgt_depth_mask].min(), tgt_disp[tgt_depth_mask].max()
                tgt_disp = cv2.inpaint(tgt_disp, (~tgt_depth_mask & ~tgt_depth_mask_inf).astype(np.uint8), 3, cv2.INPAINT_TELEA).clip(disp_min, disp_max)
                dof_focus = rng.uniform(disp_min, disp_max)
                tgt_image = depth_of_field(tgt_image, tgt_disp, dof_focus, dof_strength)
        if 'shot_noise' in image_augmentation:
            if rng.uniform() < 0.5:
                k = np.exp(rng.uniform(np.log(100), np.log(10000))) / 255
                tgt_image = (rng.poisson(tgt_image * k) / k).clip(0, 255).astype(np.uint8)
        if 'jpeg_loss' in image_augmentation:
            if rng.uniform() < 0.5:
                tgt_image = cv2.imdecode(cv2.imencode('.jpg', tgt_image, [cv2.IMWRITE_JPEG_QUALITY, rng.integers(20, 100)])[1], cv2.IMREAD_COLOR)
        if 'blurring' in image_augmentation:
            if rng.uniform() < 0.5:
                ratio = rng.uniform(0.25, 1)
                tgt_image = cv2.resize(cv2.resize(tgt_image, (int(tgt_width * ratio), int(tgt_height * ratio)), interpolation=cv2.INTER_AREA), (tgt_width, tgt_height), interpolation=rng.choice([cv2.INTER_LINEAR_EXACT, cv2.INTER_CUBIC, cv2.INTER_LANCZOS4]))
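        # `depth_unit` is an optional per-dataset scale factor (presumably converting stored
        # depth values to metres); when it is absent the instance is flagged as non-metric.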
        # convert depth to metric if necessary
        if depth_unit is not None:
            tgt_depth *= depth_unit
            instance['is_metric'] = True
        else:
            instance['is_metric'] = False
        # clamp depth maximum values
        max_depth = np.nanquantile(np.where(tgt_depth_mask, tgt_depth, np.nan), 0.01) * self.clamp_max_depth
        tgt_depth = np.clip(tgt_depth, 0, max_depth)
        tgt_depth = np.nan_to_num(tgt_depth, nan=1.0)
        if self.datasets[instance['dataset']].get('finite_depth_mask', None) == "only_known":
            tgt_depth_mask_fin = tgt_depth_mask
        else:
            tgt_depth_mask_fin = ~tgt_depth_mask_inf
        instance.update({
            'image': torch.from_numpy(tgt_image.astype(np.float32) / 255.0).permute(2, 0, 1),
            'depth': torch.from_numpy(tgt_depth).float(),
            'depth_mask': torch.from_numpy(tgt_depth_mask).bool(),
            'depth_mask_fin': torch.from_numpy(tgt_depth_mask_fin).bool(),
            'depth_mask_inf': torch.from_numpy(tgt_depth_mask_inf).bool(),
            'intrinsics': torch.from_numpy(tgt_intrinsics).float(),
        })
        return instance

    def _collate_batch(self, instances: List[Dict[str, Any]]):
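        """Stack per-instance tensors into batched tensors and carry metadata through as lists."""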
        batch = {k: torch.stack([instance[k] for instance in instances], dim=0) for k in ['image', 'depth', 'depth_mask', 'depth_mask_fin', 'depth_mask_inf', 'intrinsics']}
        batch = {
            'label_type': [instance['label_type'] for instance in instances],
            'is_metric': [instance['is_metric'] for instance in instances],
            'info': [{'dataset': instance['dataset'], 'filename': instance['filename']} for instance in instances],
            **batch,
        }
        return batch

    def get(self) -> Dict[str, Union[torch.Tensor, str]]:
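        """Return the next collated batch produced by the pipeline."""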
        return self.pipeline.get()

    def start(self):
        self.pipeline.start()

    def stop(self):
        self.pipeline.stop()

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        self.pipeline.terminate()
        self.pipeline.join()
        return False
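

# Example usage (illustrative sketch -- the dataset entry and config values below are
# assumptions for demonstration, not an actual training configuration):
#
#   config = {
#       'clamp_max_depth': 100.0,
#       'image_sizes': [[512, 384], [384, 512]],
#       'datasets': [
#           {'name': 'example', 'path': 'data/example', 'weight': 1.0, 'label_type': 'depth'},
#       ],
#   }
#   with TrainDataLoaderPipeline(config, batch_size=4) as loader:
#       batch = loader.get()   # dict with 'image', 'depth', masks, 'intrinsics', and metadata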