Spaces:
Running
on
Zero
Running
on
Zero
from typing import * | |
from numbers import Number | |
from functools import partial | |
from pathlib import Path | |
import importlib | |
import warnings | |
import json | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.utils | |
import torch.utils.checkpoint | |
import torch.version | |
import utils3d | |
from huggingface_hub import hf_hub_download | |
from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask | |
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing | |
from ..utils.tools import timeit | |
class ResidualConvBlock(nn.Module):
    """Pre-activation residual block built from two 3x3 convolutions.

    The main path is ``norm -> activation -> conv3x3 -> norm -> activation -> conv3x3``
    and its result is added to a skip path. The skip path is a 1x1 convolution
    when ``in_channels != out_channels`` and an identity otherwise.

    ### Parameters:
    - `in_channels`: number of input channels.
    - `out_channels`: number of output channels. Defaults to `in_channels`.
    - `hidden_channels`: number of channels between the two convolutions. Defaults to `in_channels`.
    - `padding_mode`: padding mode forwarded to both 3x3 convolutions.
    - `activation`: one of `'relu'`, `'leaky_relu'`, `'silu'`, `'elu'`.
    - `norm`: `'group_norm'` uses one group per 32 hidden channels (at least one group)
      for the second normalization; any other value uses a single group.
      The first normalization always uses a single group.

    ### Raises:
    - `ValueError` if `activation` is not one of the supported names.
    """

    def __init__(
        self,
        in_channels: int,
        out_channels: Optional[int] = None,
        hidden_channels: Optional[int] = None,
        padding_mode: str = 'replicate',
        activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu',
        norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
    ):
        super().__init__()
        if out_channels is None:
            out_channels = in_channels
        if hidden_channels is None:
            hidden_channels = in_channels

        # Factories rather than instances: each position in the Sequential
        # needs its own activation module.
        activation_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True),
            'silu': lambda: nn.SiLU(inplace=True),
            'elu': lambda: nn.ELU(inplace=True),
        }
        activation_cls = activation_factories.get(activation)
        if activation_cls is None:
            raise ValueError(f'Unsupported activation function: {activation}')

        # Clamp to at least one group: `hidden_channels // 32` is 0 for fewer
        # than 32 channels, which nn.GroupNorm rejects at construction time.
        hidden_groups = max(hidden_channels // 32, 1) if norm == 'group_norm' else 1

        self.layers = nn.Sequential(
            nn.GroupNorm(1, in_channels),
            activation_cls(),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
            nn.GroupNorm(hidden_groups, hidden_channels),
            activation_cls(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
        )

        if in_channels != out_channels:
            # 1x1 projection so the skip path matches the output channel count.
            self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0)
        else:
            self.skip_connection = nn.Identity()

    def forward(self, x):
        """Return `layers(x) + skip_connection(x)`."""
        # The skip path is evaluated first, before the main path runs.
        skip = self.skip_connection(x)
        return self.layers(x) + skip
class Head(nn.Module):
    """Convolutional decoder mapping backbone token features to dense output maps.

    Each of the `num_features` token feature maps is projected with its own 1x1
    convolution and the projections are summed. The result goes through one 2x
    upsampling stage per entry of `dim_upsample`, is resized bilinearly to the
    image resolution, and is fed to one output branch per entry of `dim_out`.
    A 2-channel UV grid from `normalized_view_plane_uv` is concatenated to the
    features before every upsampling stage and before the output branches.

    ### Parameters:
    - `num_features`: number of backbone feature maps to fuse.
    - `dim_in`: channel count of each backbone feature map.
    - `dim_out`: channel count of each output branch (one branch per entry).
    - `dim_proj`: channel count after the 1x1 projections.
    - `dim_upsample`: output channel count of each upsampling stage. Any sequence of ints is accepted.
    - `dim_times_res_block_hidden`: hidden-channel multiplier of the residual blocks.
    - `num_res_blocks`: residual blocks appended to each upsampling stage.
    - `res_block_norm`: `norm` argument forwarded to the residual blocks.
    - `last_res_blocks`: residual blocks inside each output branch.
    - `last_conv_channels`: intermediate channel count of each output branch.
    - `last_conv_size`: kernel size of the final convolution of each output branch.
    """

    def __init__(
        self,
        num_features: int,
        dim_in: int,
        dim_out: List[int],
        dim_proj: int = 512,
        dim_upsample: Sequence[int] = (256, 128, 128),
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1
    ):
        super().__init__()

        # Work on a private list so tuples (or other sequences) are accepted and
        # the list concatenation below cannot fail on the argument's type.
        dim_upsample = list(dim_upsample)

        self.projects = nn.ModuleList([
            nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0)
            for _ in range(num_features)
        ])

        # Stage k consumes the output of stage k-1 (the projection for k == 0);
        # the "+ 2" accounts for the UV channels concatenated in `forward`.
        stage_inputs = [dim_proj] + dim_upsample[:-1]
        self.upsample_blocks = nn.ModuleList([
            nn.Sequential(
                self._make_upsampler(in_ch + 2, out_ch),
                *(
                    ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm)
                    for _ in range(num_res_blocks)
                )
            )
            for in_ch, out_ch in zip(stage_inputs, dim_upsample)
        ])

        self.output_block = nn.ModuleList([
            self._make_output_block(
                dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
            )
            for dim_out_ in dim_out
        ])

    def _make_upsampler(self, in_channels: int, out_channels: int):
        """Build a 2x upsampler: stride-2 transposed conv followed by a 3x3 conv."""
        upsampler = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
        )
        # Initialise all four taps of the 2x2 transposed-conv kernel to the value
        # of its top-left tap. The source slice is cloned because it aliases the
        # destination.
        deconv_weight = upsampler[0].weight.data
        deconv_weight[:] = deconv_weight[:, :, :1, :1].clone()
        return upsampler

    def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
        """Build one output branch: 3x3 conv, optional residual blocks, ReLU, final conv."""
        return nn.Sequential(
            nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
            *(
                ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm)
                for _ in range(last_res_blocks)
            ),
            nn.ReLU(inplace=True),
            nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
        )

    def _concat_uv(self, x: torch.Tensor, aspect_ratio: float) -> torch.Tensor:
        """Append the UV grid for `x`'s spatial size as two extra channels."""
        # UV coordinates is for awareness of image aspect ratio
        uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=aspect_ratio, dtype=x.dtype, device=x.device)
        uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
        return torch.cat([x, uv], dim=1)

    def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
        """Decode backbone features into one dense map per output branch.

        ### Parameters:
        - `hidden_states`: iterable of `(patch_tokens, class_token)` pairs, one per
          projection. Each `patch_tokens` is indexed as (batch, tokens, channels)
          with `tokens == (H // 14) * (W // 14)`; the class token is unused.
        - `image`: tensor whose last two dimensions are the target height and width.

        ### Returns:
        - A list with one tensor per entry of `dim_out`, at the spatial size of `image`.
        """
        img_h, img_w = image.shape[-2:]
        patch_h, patch_w = img_h // 14, img_w // 14
        aspect_ratio = img_w / img_h

        # Process the hidden states: tokens -> (B, C, patch_h, patch_w), project, then sum.
        x = torch.stack([
            proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
            for proj, (feat, _cls_token) in zip(self.projects, hidden_states)
        ], dim=1).sum(dim=1)

        # Upsample stage
        # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
        for block in self.upsample_blocks:
            x = self._concat_uv(x, aspect_ratio)
            # Checkpoint layer by layer instead of the whole stage at once.
            for layer in block:
                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)

        # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
        x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
        x = self._concat_uv(x, aspect_ratio)

        if isinstance(self.output_block, nn.ModuleList):
            output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
        else:
            # Not reachable with the constructor above, which always builds a
            # ModuleList; kept in case `output_block` is replaced by a single module.
            output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)

        return output
class MoGeModel(nn.Module):
    """DINOv2 backbone plus a convolutional `Head` predicting a point map and a mask.

    `forward` returns the raw network prediction (`points`, `mask`) at the input
    resolution. `infer` additionally recovers focal length and depth shift and
    returns points, depth, intrinsics and a binary mask.
    """

    # Registered as buffers in `__init__`; used to normalise the input image.
    image_mean: torch.Tensor
    image_std: torch.Tensor

    def __init__(self,
        encoder: str = 'dinov2_vitb14',
        intermediate_layers: Union[int, List[int]] = 4,
        dim_proj: int = 512,
        dim_upsample: List[int] = [256, 128, 128],
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
        res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
        num_tokens_range: Tuple[Number, Number] = [1200, 2500],
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1,
        mask_threshold: float = 0.5,
        **deprecated_kwargs
    ):
        """Build the backbone (without pretrained weights) and the prediction head.

        ### Parameters:
        - `encoder`: name of the backbone constructor looked up in `.dinov2.hub.backbones`.
        - `intermediate_layers`: forwarded to `backbone.get_intermediate_layers`; also
          determines how many feature maps the head fuses.
        - `remap_output`: output remapping applied by `_remap_points`.
        - `num_tokens_range`: `(min, max)` token counts used by `infer` to map `resolution_level` to `num_tokens`.
        - `mask_threshold`: threshold applied to the predicted mask in `infer`.
        - `deprecated_kwargs`: legacy arguments. `trained_area_range` is converted to
          `num_tokens_range`; anything else is ignored with a warning.
        - Remaining arguments are forwarded to `Head`.
        """
        super().__init__()

        if deprecated_kwargs:
            # Process legacy arguments
            if 'trained_area_range' in deprecated_kwargs:
                trained_area_range = deprecated_kwargs.pop('trained_area_range')
                # Convert the legacy range to token counts by dividing by the 14x14 patch area.
                num_tokens_range = [trained_area_range[0] // 14 ** 2, trained_area_range[1] // 14 ** 2]
            # Only warn when something is actually being ignored.
            if deprecated_kwargs:
                warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")

        self.encoder = encoder
        self.remap_output = remap_output
        self.intermediate_layers = intermediate_layers
        self.num_tokens_range = num_tokens_range
        self.mask_threshold = mask_threshold

        # NOTE: We have copied the DINOv2 code in torchhub to this repository.
        # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
        hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
        self.backbone = hub_loader(pretrained=False)
        dim_feature = self.backbone.blocks[0].attn.qkv.in_features

        self.head = Head(
            num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
            dim_in=dim_feature,
            dim_out=[3, 1],  # 3 channels for points, 1 for the mask
            dim_proj=dim_proj,
            dim_upsample=list(dim_upsample),
            dim_times_res_block_hidden=dim_times_res_block_hidden,
            num_res_blocks=num_res_blocks,
            res_block_norm=res_block_norm,
            last_res_blocks=last_res_blocks,
            last_conv_channels=last_conv_channels,
            last_conv_size=last_conv_size
        )

        image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)

        self.register_buffer("image_mean", image_mean)
        self.register_buffer("image_std", image_std)

    @property
    def device(self) -> torch.device:
        """Device of the first model parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first model parameter."""
        return next(self.parameters()).dtype

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
        """
        Load a model from a checkpoint file.

        ### Parameters:
        - `pretrained_model_name_or_path`: path to the checkpoint file, an open binary file object, or a repo id.
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path or a file object.

        ### Returns:
        - A new instance of `MoGe` with the parameters loaded from the checkpoint.
        """
        if hasattr(pretrained_model_name_or_path, 'read'):
            # File-like object: `Path(...)` would raise on it, so hand it to torch.load directly.
            checkpoint_source = pretrained_model_name_or_path
        elif Path(pretrained_model_name_or_path).exists():
            checkpoint_source = pretrained_model_name_or_path
        else:
            # Not a local path: treat the argument as a Hugging Face repo id.
            checkpoint_source = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs
            )
        checkpoint = torch.load(checkpoint_source, map_location='cpu', weights_only=True)

        model_config = checkpoint['model_config']
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        model.load_state_dict(checkpoint['model'])
        return model

    def init_weights(self):
        "Load the backbone with pretrained dinov2 weights from torch hub"
        state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
        self.backbone.load_state_dict(state_dict)

    def enable_gradient_checkpointing(self):
        """Replace every backbone block with its gradient-checkpointing wrapper."""
        for i in range(len(self.backbone.blocks)):
            self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])

    def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
        """Apply the output remapping selected by `self.remap_output`.

        `points` is split along its last dimension into `xy` (2) and `z` (1) for
        the `'exp'` and `'sinh_exp'` modes.

        ### Raises:
        - `ValueError` if `self.remap_output` is not `'linear'`, `'sinh'`, `'exp'` or `'sinh_exp'`.
        """
        if self.remap_output == 'linear':
            pass
        elif self.remap_output == 'sinh':
            points = torch.sinh(points)
        elif self.remap_output == 'exp':
            xy, z = points.split([2, 1], dim=-1)
            z = torch.exp(z)
            points = torch.cat([xy * z, z], dim=-1)
        elif self.remap_output == 'sinh_exp':
            xy, z = points.split([2, 1], dim=-1)
            points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
        else:
            raise ValueError(f"Invalid remap output type: {self.remap_output}")
        return points

    def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
        """Predict the raw point map and mask for a batch of images.

        ### Parameters:
        - `image`: batched image tensor; its last two dimensions are height and width.
        - `num_tokens`: target number of 14x14 patches; the image is resized so that
          its area is approximately `num_tokens * 14 ** 2` pixels.

        ### Returns:
        - Dictionary with `points` (channels moved to the last dimension, remapped by
          `_remap_points`) and `mask` (channel dimension squeezed), both resized back
          to the input height and width.
        """
        original_height, original_width = image.shape[-2:]

        # Resize to expected resolution defined by num_tokens
        resize_factor = ((num_tokens * 14 ** 2) / (original_height * original_width)) ** 0.5
        resized_width, resized_height = int(original_width * resize_factor), int(original_height * resize_factor)
        image = F.interpolate(image, (resized_height, resized_width), mode="bicubic", align_corners=False, antialias=True)

        # Apply image transformation for DINOv2
        image = (image - self.image_mean) / self.image_std
        # The backbone input must have side lengths that are multiples of the patch size (14).
        image_14 = F.interpolate(image, (resized_height // 14 * 14, resized_width // 14 * 14), mode="bilinear", align_corners=False, antialias=True)

        # Get intermediate layers from the backbone
        features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)

        # Predict points (and mask)
        output = self.head(features, image)
        points, mask = output

        # Make sure fp32 precision for output
        with torch.autocast(device_type=image.device.type, dtype=torch.float32):
            # Resize to original resolution
            points = F.interpolate(points, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
            mask = F.interpolate(mask, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)

            # Post-process points and mask
            points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
            points = self._remap_points(points)     # slightly improves the performance in case of very large output values

        return_dict = {'points': points, 'mask': mask}
        return return_dict

    @torch.inference_mode()
    def infer(
        self,
        image: torch.Tensor,
        fov_x: Union[Number, torch.Tensor] = None,
        resolution_level: int = 9,
        num_tokens: int = None,
        apply_mask: bool = True,
        force_projection: bool = True,
        use_fp16: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        User-friendly inference function. Runs under `torch.inference_mode()`, so no autograd graph is recorded.

        ### Parameters
        - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
        - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
        - `resolution_level`: An integer [0-9] for the resolution level for inference.
            The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size.
            `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.
        - `num_tokens`: number of tokens used for inference. A integer in the (suggested) range of `[1200, 2500]`.
            `resolution_level` will be ignored if `num_tokens` is provided. Default: None
        - `apply_mask`: if True, the output point map will be masked using the predicted mask. Default: True
        - `force_projection`: if True, the output point map will be recomputed to match the projection constraint. Default: True
        - `use_fp16`: if True, use mixed precision to speed up inference. Default: True

        ### Returns

        A dictionary containing the following keys:
        - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
        - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
        - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
        - `mask`: boolean tensor of shape (B, H, W) or (H, W), the predicted mask thresholded at `mask_threshold`.
        - `mask_prob`: tensor of shape (B, H, W) or (H, W) with the unthresholded mask prediction.
        """
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        aspect_ratio = original_width / original_height

        if num_tokens is None:
            # Interpolate linearly between the configured token bounds.
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

        # Autocast is only enabled when requested and the weights are not already fp16.
        with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
            output = self.forward(image, num_tokens)
        points, mask = output['points'], output['mask']

        # Always process the output in fp32 precision
        with torch.autocast(device_type=self.device.type, dtype=torch.float32):
            points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x])

            mask_binary = mask > self.mask_threshold

            # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
            if fov_x is None:
                focal, shift = recover_focal_shift(points, mask_binary)
            else:
                focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
                if focal.ndim == 0:
                    # A scalar FoV applies to every image in the batch.
                    focal = focal[None].expand(points.shape[0])
                _, shift = recover_focal_shift(points, mask_binary, focal=focal)
            fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
            fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
            intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
            depth = points[..., 2] + shift[..., None, None]

            # If projection constraint is forced, recompute the point map using the actual depth map
            if force_projection:
                points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
            else:
                # Otherwise only add the recovered shift to the z component.
                points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]

            # Apply mask if needed
            if apply_mask:
                points = torch.where(mask_binary[..., None], points, torch.inf)
                depth = torch.where(mask_binary, depth, torch.inf)

        return_dict = {
            'points': points,
            'intrinsics': intrinsics,
            'depth': depth,
            'mask': mask_binary,
            "mask_prob": mask,
        }

        if omit_batch_dim:
            return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}

        return return_dict