import os

import torchvision.transforms as tfm
import py3_wget
import numpy as np
import torch

from matching import BaseMatcher, WEIGHTS_DIR, THIRD_PARTY_DIR
from matching.utils import resize_to_divisible, add_to_path

# The vendored mast3r repo must be on sys.path before its modules can be imported.
add_to_path(THIRD_PARTY_DIR.joinpath("mast3r"))

from mast3r.model import AsymmetricMASt3R
from mast3r.fast_nn import fast_reciprocal_NNs
from dust3r.inference import inference


class Mast3rMatcher(BaseMatcher):
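    """MASt3R wrapper conforming to the BaseMatcher interface.

    Runs the two-view MASt3R network, matches its dense descriptors with
    fast reciprocal nearest neighbors, and returns matched keypoints in the
    coordinates of the original (unresized) images.
    """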

    model_path = WEIGHTS_DIR.joinpath(
        "MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"
    )
    vit_patch_size = 16  # input H and W must be divisible by the ViT patch size

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        # MASt3R expects inputs normalized from [0, 1] to [-1, 1].
        self.normalize = tfm.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))

        self.verbose = False
        self.max_keypoints = max_num_keypoints

        self.download_weights()
        self.model = AsymmetricMASt3R.from_pretrained(self.model_path).to(device)

    @staticmethod
    def download_weights():
        url = "https://download.europe.naverlabs.com/ComputerVision/MASt3R/MASt3R_ViTLarge_BaseDecoder_512_catmlpdpt_metric.pth"

        # Fetch the checkpoint only if it is not already cached on disk.
        if not os.path.isfile(Mast3rMatcher.model_path):
            print("Downloading MASt3R (ViT-Large)... (takes a while)")
            py3_wget.download_file(url, Mast3rMatcher.model_path)

    def preprocess(self, img):
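        """Resize a CHW image tensor to patch-aligned dimensions and normalize it.

        Returns the batched, normalized tensor along with the original (h, w),
        which _forward later uses to map matches back to input coordinates.
        """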
        _, h, w = img.shape
        orig_shape = h, w

        # Both dimensions must be divisible by the ViT patch size.
        img = resize_to_divisible(img, self.vit_patch_size)

        img = self.normalize(img).unsqueeze(0)

        return img, orig_shape

    def _forward(self, img0, img1):
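        """Match two images and return the valid, rescaled correspondences."""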
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        # dust3r's inference API expects a list of (view0, view1) dict pairs.
        img_pair = [
            {
                "img": img0,
                "idx": 0,
                "instance": 0,
                "true_shape": np.int32([img0.shape[-2:]]),
            },
            {
                "img": img1,
                "idx": 1,
                "instance": 1,
                "true_shape": np.int32([img1.shape[-2:]]),
            },
        ]
        output = inference(
            [tuple(img_pair)], self.model, self.device, batch_size=1, verbose=False
        )
        view1, pred1 = output["view1"], output["pred1"]
        view2, pred2 = output["view2"], output["pred2"]

        # Dense per-pixel descriptors from the two decoder heads.
        desc1, desc2 = (
            pred1["desc"].squeeze(0).detach(),
            pred2["desc"].squeeze(0).detach(),
        )

        # Mutual (reciprocal) nearest-neighbor matching on the dense descriptors,
        # subsampled on an 8-pixel grid and capped at max_keypoints matches.
        matches_im0, matches_im1 = fast_reciprocal_NNs(
            desc1,
            desc2,
            subsample_or_initxy1=8,
            device=self.device,
            dist="dot",
            block_size=2**13,
            max_matches=self.max_keypoints,
        )

        # Discard matches within 3 px of the image borders, where the dense
        # predictions are unreliable.
        H0, W0 = view1["true_shape"][0]
        valid_matches_im0 = (
            (matches_im0[:, 0] >= 3)
            & (matches_im0[:, 0] < int(W0) - 3)
            & (matches_im0[:, 1] >= 3)
            & (matches_im0[:, 1] < int(H0) - 3)
        )

        H1, W1 = view2["true_shape"][0]
        valid_matches_im1 = (
            (matches_im1[:, 0] >= 3)
            & (matches_im1[:, 0] < int(W1) - 3)
            & (matches_im1[:, 1] >= 3)
            & (matches_im1[:, 1] < int(H1) - 3)
        )

        # Keep only matches that are valid in both views.
        valid_matches = valid_matches_im0 & valid_matches_im1
        mkpts0, mkpts1 = matches_im0[valid_matches], matches_im1[valid_matches]

        # Map match coordinates from the resized images back to the originals.
        H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        # Only matched keypoints are produced; no standalone keypoints/descriptors.
        return mkpts0, mkpts1, None, None, None, None
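

# Minimal usage sketch (not part of the original module). It calls _forward
# directly for illustration; in practice the matcher is invoked through the
# BaseMatcher interface. Assumes dust3r's inference handles device transfer
# and that inputs are CHW float tensors in [0, 1] (random stand-ins here).
if __name__ == "__main__":
    matcher = Mast3rMatcher(device="cuda" if torch.cuda.is_available() else "cpu")
    img0 = torch.rand(3, 480, 640)  # replace with real images
    img1 = torch.rand(3, 480, 640)
    mkpts0, mkpts1, *_ = matcher._forward(img0, img1)
    print(f"{len(mkpts0)} matches")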