import cv2
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.monoD.depth_anything.dpt import DPT_DINOv2
from models.monoD.depth_anything.util.transform import Resize

def build(config):
    """
    Build the monocular depth model from the config.
    NOTE: the config should contain the following fields:
        - encoder: the encoder type of the model ('vits', 'vitb' or 'vitl')
        - load_from: the path to the pretrained checkpoint
        - localhub: whether to load the DINOv2 backbone from a local torch.hub cache
    """
    args = config
    assert args.encoder in ['vits', 'vitb', 'vitl']
    if args.encoder == 'vits':
        depth_anything = DPT_DINOv2(encoder='vits', features=64,
                                    out_channels=[48, 96, 192, 384],
                                    localhub=args.localhub).cuda()
    elif args.encoder == 'vitb':
        depth_anything = DPT_DINOv2(encoder='vitb', features=128,
                                    out_channels=[96, 192, 384, 768],
                                    localhub=args.localhub).cuda()
    else:
        depth_anything = DPT_DINOv2(encoder='vitl', features=256,
                                    out_channels=[256, 512, 1024, 1024],
                                    localhub=args.localhub).cuda()
    depth_anything.load_state_dict(torch.load(args.load_from,
                                              map_location='cpu'), strict=True)
    total_params = sum(param.numel() for param in depth_anything.parameters())
    print('Total parameters: {:.2f}M'.format(total_params / 1e6))
    depth_anything.eval()
    return depth_anything
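
# A minimal usage sketch for build() (assumptions: the checkpoint path below
# is a placeholder, not a repo-provided file; `config` can be any object
# exposing the three fields named in the docstring, e.g. argparse.Namespace):
#
#   import argparse
#   cfg = argparse.Namespace(encoder='vits',
#                            load_from='checkpoints/depth_anything_vits14.pth',
#                            localhub=False)
#   model = build(cfg)  # returns an eval-mode DPT_DINOv2 on the GPU
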
class DepthAnything(nn.Module):
    def __init__(self, args):
        super(DepthAnything, self).__init__()
        # build the chosen model
        self.dpAny = build(args)

    def infer(self, rgbs):
        """
        Infer a relative depth map from the input RGB images.
        Args:
            rgbs: the input RGB images, B x 3 x H x W (CUDA tensor)
        Asserts:
            the input must be a 4-D CUDA tensor
        """
        assert rgbs.is_cuda and rgbs.ndim == 4
        B, C, H, W = rgbs.shape
        # prepare the input
        Resizer = Resize(
            width=518,
            height=518,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method='lower_bound',
            image_interpolation_method=cv2.INTER_CUBIC,
        )
        # NOTE: step 1 -- Resize.
        # Resize.get_size expects (width, height); rgbs is B x 3 x H x W,
        # so pass shape[3] (W) first and shape[2] (H) second.
        width, height = Resizer.get_size(
            rgbs.shape[3], rgbs.shape[2]
        )
        rgbs = F.interpolate(
            rgbs, (int(height), int(width)), mode='bicubic', align_corners=False
        )
        # NOTE: step 2 -- NormalizeImage, applied in-tensor with the
        # standard ImageNet mean/std used by the DINOv2 backbone
        mean_ = torch.tensor([0.485, 0.456, 0.406],
                             device=rgbs.device).view(1, 3, 1, 1)
        std_ = torch.tensor([0.229, 0.224, 0.225],
                            device=rgbs.device).view(1, 3, 1, 1)
        rgbs = (rgbs - mean_) / std_
        # NOTE: step 3 -- PrepareForNet is a no-op here: the input is
        # already a contiguous B x 3 x H x W float tensor.
        # get the relative depth (disparity) map
        disp = self.dpAny(rgbs)
        disp = F.interpolate(
            disp[:, None], (H, W),
            mode='bilinear', align_corners=False
        )
        # NOTE: the network predicts relative inverse depth (disparity);
        # no clamping of the depth range is applied here.
        depth_map = disp
        return depth_map
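

if __name__ == "__main__":
    # Smoke-test sketch, assuming a CUDA device and a locally downloaded
    # Depth Anything checkpoint (the path below is a placeholder).
    import argparse
    cfg = argparse.Namespace(encoder='vits',
                             load_from='checkpoints/depth_anything_vits14.pth',
                             localhub=False)
    model = DepthAnything(cfg)
    rgbs = torch.rand(1, 3, 480, 640, device='cuda')  # B x 3 x H x W in [0, 1]
    with torch.no_grad():
        depth = model.infer(rgbs)
    print(depth.shape)  # -> torch.Size([1, 1, 480, 640])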