import argparse

import cv2
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

from models.monoD.depth_anything.dpt import DPT_DINOv2
from models.monoD.depth_anything.util.transform import Resize

def build(config):
    """
    Build the DPT_DINOv2 model from the config.

    NOTE: the config should contain the following fields:
        - encoder: the encoder type, one of 'vits', 'vitb', 'vitl'
        - load_from: the path to the pretrained checkpoint
        - localhub: whether to load the DINOv2 backbone from a local torch hub
    """
    args = config
    assert args.encoder in ['vits', 'vitb', 'vitl']
    if args.encoder == 'vits':
        depth_anything = DPT_DINOv2(encoder='vits', features=64,
                                    out_channels=[48, 96, 192, 384],
                                    localhub=args.localhub).cuda()
    elif args.encoder == 'vitb':
        depth_anything = DPT_DINOv2(encoder='vitb', features=128,
                                    out_channels=[96, 192, 384, 768],
                                    localhub=args.localhub).cuda()
    else:
        depth_anything = DPT_DINOv2(encoder='vitl', features=256,
                                    out_channels=[256, 512, 1024, 1024],
                                    localhub=args.localhub).cuda()

    depth_anything.load_state_dict(
        torch.load(args.load_from, map_location='cpu'), strict=True
    )
    total_params = sum(param.numel() for param in depth_anything.parameters())
    print('Total parameters: {:.2f}M'.format(total_params / 1e6))
    depth_anything.eval()
    return depth_anything
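
# A minimal usage sketch for `build` (the checkpoint path below is
# hypothetical; any object with `encoder`, `load_from`, and `localhub`
# attributes works, e.g. an argparse.Namespace):
#
#   cfg = argparse.Namespace(
#       encoder='vits',
#       load_from='checkpoints/depth_anything_vits14.pth',  # hypothetical path
#       localhub=True,
#   )
#   model = build(cfg)  # moves the model to the GPU and sets eval mode
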
class DepthAnything(nn.Module):
    def __init__(self, args):
        super(DepthAnything, self).__init__()
        # build the chosen model variant from the config
        self.dpAny = build(args)

    def infer(self, rgbs):
        """
        Infer a depth (disparity) map from the input RGB images.

        Args:
            rgbs: input RGB images, a B x 3 x H x W CUDA tensor
                with values in [0, 1]

        Asserts:
            the input must be a 4-D CUDA tensor
        """
        assert rgbs.is_cuda and rgbs.dim() == 4
        B, C, H, W = rgbs.shape
        # prepare the input
        Resizer = Resize(
            width=518,
            height=518,
            resize_target=False,
            keep_aspect_ratio=True,
            ensure_multiple_of=14,
            resize_method='lower_bound',
            image_interpolation_method=cv2.INTER_CUBIC,
        )
        # NOTE: step 1 -- Resize. get_size takes and returns (width, height),
        # so pass W, H; the result is a multiple of 14 with aspect ratio kept
        width, height = Resizer.get_size(W, H)
        rgbs = F.interpolate(
            rgbs, (int(height), int(width)), mode='bicubic', align_corners=False
        )
        # NOTE: step 2 -- NormalizeImage, applied inline with ImageNet statistics
        mean_ = torch.tensor([0.485, 0.456, 0.406],
                             device=rgbs.device).view(1, 3, 1, 1)
        std_ = torch.tensor([0.229, 0.224, 0.225],
                            device=rgbs.device).view(1, 3, 1, 1)
        rgbs = (rgbs - mean_) / std_
        # NOTE: step 3 -- PrepareForNet is a no-op here: the input is already
        # a float NCHW tensor, which is what the network expects

        # run the network to get the disparity map
        disp = self.dpAny(rgbs)
        disp = F.interpolate(
            disp[:, None], (H, W),
            mode='bilinear', align_corners=False
        )
        # NOTE: the network predicts relative (inverse-depth) disparity; it is
        # returned as-is, without clamping the far/near range
        depth_map = disp
        return depth_map
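
if __name__ == '__main__':
    # Minimal smoke test (a sketch, not part of the original module): reads an
    # image with OpenCV, runs inference, and saves a normalized disparity map.
    # The default checkpoint path is hypothetical -- point --load_from at a
    # real Depth Anything checkpoint. Requires a CUDA-capable GPU.
    parser = argparse.ArgumentParser()
    parser.add_argument('--encoder', default='vits',
                        choices=['vits', 'vitb', 'vitl'])
    parser.add_argument('--load_from',
                        default='checkpoints/depth_anything_vits14.pth')
    parser.add_argument('--localhub', action='store_true')
    parser.add_argument('--image', required=True)
    args = parser.parse_args()

    model = DepthAnything(args)

    # BGR uint8 H x W x C -> RGB float B x C x H x W in [0, 1]
    img = cv2.cvtColor(cv2.imread(args.image), cv2.COLOR_BGR2RGB)
    rgbs = torch.from_numpy(img).permute(2, 0, 1)[None].float().cuda() / 255.0

    with torch.no_grad():
        depth = model.infer(rgbs)  # B x 1 x H x W disparity map

    # normalize to [0, 255] for visualization
    d = depth[0, 0].cpu().numpy()
    d = (d - d.min()) / (d.max() - d.min() + 1e-8) * 255.0
    cv2.imwrite('depth_vis.png', d.astype(np.uint8))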