xiaoyuxi
gradio_app
a51c6d2
from typing import *
from numbers import Number
from functools import partial
from pathlib import Path
import importlib
import warnings
import json
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils
import torch.utils.checkpoint
import torch.version
import utils3d
from huggingface_hub import hf_hub_download
from ..utils.geometry_torch import normalized_view_plane_uv, recover_focal_shift, gaussian_blur_2d, dilate_with_mask
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
from ..utils.tools import timeit
class ResidualConvBlock(nn.Module):
    """Pre-activation residual block: (GroupNorm -> activation -> 3x3 conv) applied twice, plus a skip path.

    The skip path is a 1x1 conv when `in_channels != out_channels`, otherwise the identity.
    """

    def __init__(self, in_channels: int, out_channels: Optional[int] = None, hidden_channels: Optional[int] = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
        """
        ### Parameters:
        - `in_channels`: number of input channels.
        - `out_channels`: number of output channels. Defaults to `in_channels`.
        - `hidden_channels`: output width of the first 3x3 conv. Defaults to `in_channels`.
        - `padding_mode`: padding mode of both 3x3 convs.
        - `activation`: one of 'relu', 'leaky_relu' (negative slope 0.2), 'silu', 'elu'.
        - `norm`: normalization before the second conv. 'group_norm' uses one group per 32 hidden
          channels (at least one group); any other value uses a single group. The first
          normalization always uses a single group.
        ### Raises:
        - `ValueError`: if `activation` is not supported.
        """
        super(ResidualConvBlock, self).__init__()
        if out_channels is None:
            out_channels = in_channels
        if hidden_channels is None:
            hidden_channels = in_channels

        # In-place activations are safe: each one follows a GroupNorm whose output is a fresh
        # tensor, so the block input (also used by the skip path) is never overwritten.
        activation_factories = {
            'relu': lambda: nn.ReLU(inplace=True),
            'leaky_relu': lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True),
            'silu': lambda: nn.SiLU(inplace=True),
            'elu': lambda: nn.ELU(inplace=True),
        }
        activation_cls = activation_factories.get(activation)
        if activation_cls is None:
            raise ValueError(f'Unsupported activation function: {activation}')

        # `hidden_channels // 32` is 0 for narrow blocks, which GroupNorm rejects; clamp to one group.
        hidden_groups = max(1, hidden_channels // 32) if norm == 'group_norm' else 1

        self.layers = nn.Sequential(
            nn.GroupNorm(1, in_channels),
            activation_cls(),
            nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
            nn.GroupNorm(hidden_groups, hidden_channels),
            activation_cls(),
            nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
        )

        self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return `layers(x) + skip_connection(x)`; spatial size is preserved by the padded 3x3 convs."""
        skip = self.skip_connection(x)
        return self.layers(x) + skip
class Head(nn.Module):
    """Convolutional decoder turning backbone token features into dense per-pixel outputs.

    Each input feature level is reshaped to a (patch_h, patch_w) grid, projected to `dim_proj`
    channels with its own 1x1 conv, and the projections are summed. The result goes through one
    2x upsampling stage per entry of `dim_upsample` (ConvTranspose2d + 3x3 conv, then
    `num_res_blocks` residual blocks). It is then bilinearly resized to the image size and fed to one
    output branch per entry of `dim_out`. Two UV channels are concatenated before every
    upsampling stage and before the output branches.
    """

    def __init__(
        self,
        num_features: int,
        dim_in: int,
        dim_out: List[int],
        dim_proj: int = 512,
        dim_upsample: Sequence[int] = (256, 128, 128),
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1
    ):
        """
        ### Parameters:
        - `num_features`: number of feature levels (one 1x1 projection each).
        - `dim_in`: channel width of the incoming features.
        - `dim_out`: output channels of each output branch.
        - `dim_proj`: channels after projection.
        - `dim_upsample`: output channels of each 2x upsampling stage (any sequence; a tuple by default
          to avoid a shared mutable default).
        - `dim_times_res_block_hidden`: hidden-width multiplier for the residual blocks.
        - `num_res_blocks`: residual blocks after each upsampling stage.
        - `res_block_norm`: `norm` argument of the residual blocks.
        - `last_res_blocks`, `last_conv_channels`, `last_conv_size`: shape of each output branch.
        """
        super().__init__()
        dim_upsample = list(dim_upsample)

        self.projects = nn.ModuleList([
            nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0) for _ in range(num_features)
        ])

        # Every stage consumes the previous stage's channels plus the 2 UV channels.
        self.upsample_blocks = nn.ModuleList([
            nn.Sequential(
                self._make_upsampler(in_ch + 2, out_ch),
                *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
            ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
        ])

        self.output_block = nn.ModuleList([
            self._make_output_block(
                dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
            ) for dim_out_ in dim_out
        ])

    def _make_upsampler(self, in_channels: int, out_channels: int):
        """Build a 2x upsampler: stride-2 ConvTranspose2d followed by a replicate-padded 3x3 conv."""
        upsampler = nn.Sequential(
            nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
            nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
        )
        # At initialization, copy the top-left tap of each 2x2 kernel into all four taps, so every
        # input pixel is spread identically over its 2x2 output block (nearest-neighbour-like).
        upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
        return upsampler

    def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
        """Build one output branch: 3x3 conv -> residual blocks -> ReLU -> conv to `dim_out` channels."""
        return nn.Sequential(
            nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
            *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
            nn.ReLU(inplace=True),
            nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
        )

    @staticmethod
    def _append_uv(x: torch.Tensor, aspect_ratio: float) -> torch.Tensor:
        """Concatenate the UV map from `normalized_view_plane_uv` (sized like `x`) as extra channels.

        UV coordinates make the decoder aware of the image aspect ratio.
        """
        uv = normalized_view_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=aspect_ratio, dtype=x.dtype, device=x.device)
        uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
        return torch.cat([x, uv], dim=1)

    def forward(self, hidden_states: Sequence[Tuple[torch.Tensor, torch.Tensor]], image: torch.Tensor):
        """
        ### Parameters:
        - `hidden_states`: one `(tokens, cls_token)` pair per feature level; only `tokens` is used.
          It is assumed to be (B, N, C) with N = (H // 14) * (W // 14) — TODO confirm against the backbone.
        - `image`: only its spatial size (H, W) is read.
        ### Returns:
        - A list with one (B, dim_out[i], H, W) tensor per output branch.
        """
        img_h, img_w = image.shape[-2:]
        patch_h, patch_w = img_h // 14, img_w // 14
        aspect_ratio = img_w / img_h

        # Tokens -> (B, C, patch_h, patch_w) grid, project each level, and sum over levels
        x = torch.stack([
            proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
            for proj, (feat, _cls_token) in zip(self.projects, hidden_states)
        ], dim=1).sum(dim=1)

        # Upsample stage:
        # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> ...
        for block in self.upsample_blocks:
            x = self._append_uv(x, aspect_ratio)
            for layer in block:
                x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)

        # Final upsampled grid -> exact image resolution
        x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
        x = self._append_uv(x, aspect_ratio)

        if isinstance(self.output_block, nn.ModuleList):
            output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
        else:
            output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)
        return output
class MoGeModel(nn.Module):
    """Point-map model: a DINOv2-style backbone feeding a convolutional `Head`.

    The head has two output branches (`dim_out=[3, 1]`): a 3-channel point map and a 1-channel
    mask. `forward` returns the (remapped) raw predictions. `infer` also recovers depth and
    camera intrinsics.
    """

    image_mean: torch.Tensor
    image_std: torch.Tensor

    def __init__(self,
        encoder: str = 'dinov2_vitb14',
        intermediate_layers: Union[int, List[int]] = 4,
        dim_proj: int = 512,
        dim_upsample: Sequence[int] = (256, 128, 128),
        dim_times_res_block_hidden: int = 1,
        num_res_blocks: int = 1,
        remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
        res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
        num_tokens_range: Tuple[Number, Number] = (1200, 2500),
        last_res_blocks: int = 0,
        last_conv_channels: int = 32,
        last_conv_size: int = 1,
        mask_threshold: float = 0.5,
        **deprecated_kwargs
    ):
        """
        ### Parameters:
        - `encoder`: name of a backbone factory in the bundled `.dinov2.hub.backbones` module.
        - `intermediate_layers`: passed to the backbone's `get_intermediate_layers`. The head gets one
          projection per level: this value if it is an int, otherwise its length.
        - `dim_proj`, `dim_upsample`, `dim_times_res_block_hidden`, `num_res_blocks`, `res_block_norm`,
          `last_res_blocks`, `last_conv_channels`, `last_conv_size`: forwarded to `Head`.
        - `remap_output`: how raw point predictions are remapped (see `_remap_points`).
        - `num_tokens_range`: (min, max) token counts that `infer` interpolates between.
        - `mask_threshold`: mask values above this are treated as valid by `infer`.
        - `deprecated_kwargs`: the legacy `trained_area_range` (a pixel area) is converted to
          `num_tokens_range` by dividing by 14 ** 2. Any other extra argument is ignored with a warning.
        """
        super(MoGeModel, self).__init__()

        # Process legacy arguments
        if 'trained_area_range' in deprecated_kwargs:
            trained_area_range = deprecated_kwargs.pop('trained_area_range')
            num_tokens_range = [trained_area_range[0] // 14 ** 2, trained_area_range[1] // 14 ** 2]
        # Only warn about arguments that were actually ignored (previously this fired with `{}`
        # whenever `trained_area_range` had been consumed).
        if deprecated_kwargs:
            warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")

        self.encoder = encoder
        self.remap_output = remap_output
        self.intermediate_layers = intermediate_layers
        self.num_tokens_range = num_tokens_range
        self.mask_threshold = mask_threshold

        # NOTE: We have copied the DINOv2 code in torchhub to this repository.
        # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
        hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
        self.backbone = hub_loader(pretrained=False)
        dim_feature = self.backbone.blocks[0].attn.qkv.in_features

        self.head = Head(
            num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
            dim_in=dim_feature,
            dim_out=[3, 1],
            dim_proj=dim_proj,
            dim_upsample=list(dim_upsample),
            dim_times_res_block_hidden=dim_times_res_block_hidden,
            num_res_blocks=num_res_blocks,
            res_block_norm=res_block_norm,
            last_res_blocks=last_res_blocks,
            last_conv_channels=last_conv_channels,
            last_conv_size=last_conv_size
        )

        # Per-channel RGB normalization constants applied in `forward`
        image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
        image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
        self.register_buffer("image_mean", image_mean)
        self.register_buffer("image_std", image_std)

    @property
    def device(self) -> torch.device:
        """Device of the first parameter."""
        return next(self.parameters()).device

    @property
    def dtype(self) -> torch.dtype:
        """Dtype of the first parameter."""
        return next(self.parameters()).dtype

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
        """
        Load a model from a checkpoint file.

        ### Parameters:
        - `pretrained_model_name_or_path`: an existing local checkpoint path, an open binary file
          object, or otherwise a Hugging Face repo id (its `model.pt` is downloaded).
        - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
        - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored unless downloading.

        ### Returns:
        - A new instance of `MoGe` with the parameters loaded from the checkpoint. The checkpoint must be
          a dict with `'model_config'` (constructor kwargs) and `'model'` (state dict); it is loaded with
          `map_location='cpu'` and `weights_only=True`.
        """
        # `Path(...)` raises TypeError for file objects, so detect streams before touching the filesystem.
        is_stream = hasattr(pretrained_model_name_or_path, 'read')
        if is_stream or Path(pretrained_model_name_or_path).exists():
            checkpoint_source = pretrained_model_name_or_path
        else:
            checkpoint_source = hf_hub_download(
                repo_id=pretrained_model_name_or_path,
                repo_type="model",
                filename="model.pt",
                **hf_kwargs
            )
        checkpoint = torch.load(checkpoint_source, map_location='cpu', weights_only=True)

        model_config = checkpoint['model_config']
        if model_kwargs is not None:
            model_config.update(model_kwargs)
        model = cls(**model_config)
        model.load_state_dict(checkpoint['model'])
        return model

    def init_weights(self):
        "Load the backbone with pretrained dinov2 weights from torch hub ('facebookresearch/dinov2', `self.encoder`)."
        state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
        self.backbone.load_state_dict(state_dict)

    def enable_gradient_checkpointing(self):
        """Replace every backbone block with its `wrap_module_with_gradient_checkpointing` wrapper."""
        for i, block in enumerate(self.backbone.blocks):
            self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(block)

    def _remap_points(self, points: torch.Tensor) -> torch.Tensor:
        """Map raw (..., 3) predictions to points according to `self.remap_output`.

        - 'linear': unchanged.
        - 'sinh': sinh on all three channels.
        - 'exp': z = exp(z_raw), and x, y are multiplied by that z.
        - 'sinh_exp': sinh on x, y and exp on z.

        ### Raises:
        - `ValueError`: for any other `remap_output` value.
        """
        if self.remap_output == 'linear':
            pass
        elif self.remap_output == 'sinh':
            points = torch.sinh(points)
        elif self.remap_output == 'exp':
            xy, z = points.split([2, 1], dim=-1)
            z = torch.exp(z)
            points = torch.cat([xy * z, z], dim=-1)
        elif self.remap_output == 'sinh_exp':
            xy, z = points.split([2, 1], dim=-1)
            points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
        else:
            raise ValueError(f"Invalid remap output type: {self.remap_output}")
        return points

    def forward(self, image: torch.Tensor, num_tokens: int) -> Dict[str, torch.Tensor]:
        """Predict a point map and mask for a (B, 3, H, W) image batch.

        The image is resized so that it covers roughly `num_tokens` 14x14 patches. The predictions
        are resized back to (H, W).

        ### Returns:
        - `points`: (B, H, W, 3), remapped by `_remap_points`.
        - `mask`: (B, H, W).
        Half-precision outputs are promoted to float32; float32/float64 are kept.
        """
        original_height, original_width = image.shape[-2:]

        # Resize to expected resolution defined by num_tokens (aspect ratio preserved)
        resize_factor = ((num_tokens * 14 ** 2) / (original_height * original_width)) ** 0.5
        resized_width, resized_height = int(original_width * resize_factor), int(original_height * resize_factor)
        image = F.interpolate(image, (resized_height, resized_width), mode="bicubic", align_corners=False, antialias=True)

        # Apply image transformation for DINOv2, then crop the size down to a multiple of 14
        image = (image - self.image_mean) / self.image_std
        image_14 = F.interpolate(image, (resized_height // 14 * 14, resized_width // 14 * 14), mode="bilinear", align_corners=False, antialias=True)

        # Get intermediate layers from the backbone
        features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)

        # Predict points (and mask)
        points, mask = self.head(features, image)

        # Make sure of (at least) fp32 precision for the output: leave any enclosing autocast region
        # and cast explicitly. (autocast with dtype=float32 is rejected on CPU with a warning, and it
        # never casts the inputs by itself.)
        with torch.autocast(device_type=image.device.type, enabled=False):
            points = points.to(torch.promote_types(points.dtype, torch.float32))
            mask = mask.to(torch.promote_types(mask.dtype, torch.float32))

            # Resize to original resolution
            points = F.interpolate(points, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)
            mask = F.interpolate(mask, (original_height, original_width), mode='bilinear', align_corners=False, antialias=False)

            # Post-process points and mask
            points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
            points = self._remap_points(points)     # slightly improves the performance in case of very large output values

        return {'points': points, 'mask': mask}

    @torch.inference_mode()
    def infer(
        self,
        image: torch.Tensor,
        fov_x: Optional[Union[Number, torch.Tensor]] = None,
        resolution_level: int = 9,
        num_tokens: Optional[int] = None,
        apply_mask: bool = True,
        force_projection: bool = True,
        use_fp16: bool = True,
    ) -> Dict[str, torch.Tensor]:
        """
        User-friendly inference function

        ### Parameters
        - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
        - `fov_x`: the horizontal camera FoV in degrees. If None, it will be inferred from the predicted point map. Default: None
        - `resolution_level`: An integer [0-9] for the resolution level for inference.
            The higher, the finer details will be captured, but slower. Defaults to 9. Note that it is irrelevant to the output size, which is always the same as the input size.
            `resolution_level` actually controls `num_tokens`. See `num_tokens` for more details.
        - `num_tokens`: number of tokens used for inference. An integer in the (suggested) range of `[1200, 2500]`.
            `resolution_level` will be ignored if `num_tokens` is provided. Default: None
        - `apply_mask`: if True, the output point map and depth are set to `inf` outside the predicted mask. Default: True
        - `force_projection`: if True, the output point map will be recomputed from the depth map and intrinsics to match the projection constraint. Default: True
        - `use_fp16`: if True, use mixed precision to speed up inference. Default: True

        ### Returns
        A dictionary containing the following keys (the batch dimension is dropped for a 3-D input):
        - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
        - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
        - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics, built from the
            recovered focal lengths with center (0.5, 0.5) (presumably normalized image coordinates).
        - `mask`: boolean tensor of shape (B, H, W) or (H, W): `mask_prob > self.mask_threshold`.
        - `mask_prob`: the predicted mask values before thresholding, same shape as `mask`.
        """
        if image.dim() == 3:
            omit_batch_dim = True
            image = image.unsqueeze(0)
        else:
            omit_batch_dim = False
        image = image.to(dtype=self.dtype, device=self.device)

        original_height, original_width = image.shape[-2:]
        aspect_ratio = original_width / original_height

        # Linearly interpolate the token budget inside `num_tokens_range`
        if num_tokens is None:
            min_tokens, max_tokens = self.num_tokens_range
            num_tokens = int(min_tokens + (resolution_level / 9) * (max_tokens - min_tokens))

        with torch.autocast(device_type=self.device.type, dtype=torch.float16, enabled=use_fp16 and self.dtype != torch.float16):
            output = self.forward(image, num_tokens)
        points, mask = output['points'], output['mask']

        # Always process the output in fp32 precision: autocast is disabled here and tensors are cast explicitly.
        with torch.autocast(device_type=self.device.type, enabled=False):
            points, mask, fov_x = map(lambda x: x.float() if isinstance(x, torch.Tensor) else x, [points, mask, fov_x])

            mask_binary = mask > self.mask_threshold

            # Get camera-space point map. (Focal here is the focal length relative to half the image diagonal)
            if fov_x is None:
                focal, shift = recover_focal_shift(points, mask_binary)
            else:
                focal = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 / torch.tan(torch.deg2rad(torch.as_tensor(fov_x, device=points.device, dtype=points.dtype) / 2))
                if focal.ndim == 0:
                    # A scalar FoV applies to the whole batch
                    focal = focal[None].expand(points.shape[0])
                _, shift = recover_focal_shift(points, mask_binary, focal=focal)
            fx = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5 / aspect_ratio
            fy = focal / 2 * (1 + aspect_ratio ** 2) ** 0.5
            intrinsics = utils3d.torch.intrinsics_from_focal_center(fx, fy, 0.5, 0.5)
            depth = points[..., 2] + shift[..., None, None]

            # If projection constraint is forced, recompute the point map using the actual depth map
            if force_projection:
                points = utils3d.torch.depth_to_points(depth, intrinsics=intrinsics)
            else:
                points = points + torch.stack([torch.zeros_like(shift), torch.zeros_like(shift), shift], dim=-1)[..., None, None, :]

            # Apply mask if needed
            if apply_mask:
                points = torch.where(mask_binary[..., None], points, torch.inf)
                depth = torch.where(mask_binary, depth, torch.inf)

        return_dict = {
            'points': points,
            'intrinsics': intrinsics,
            'depth': depth,
            'mask': mask_binary,
            "mask_prob": mask,
        }

        if omit_batch_dim:
            return_dict = {k: v.squeeze(0) for k, v in return_dict.items()}

        return return_dict