|
import torch |
|
from torchvision.datasets.utils import download_file_from_google_drive |
|
|
|
from matching import BaseMatcher, WEIGHTS_DIR |
|
|
|
|
|
class xFeatSteerersMatcher(BaseMatcher):
    """XFeat matcher that also tries "steered" versions of the image-0 descriptors.

    Matching is run once on the unmodified descriptors and then on three
    successively steered versions of the first image's descriptors; the variant
    that yields the most matches is kept. Steering is done either with a fixed
    channel permutation (``steerer_type="perm"``) or with a learned 64x64 linear
    map (``steerer_type="learned"``).

    NOTE(review): each steering step presumably corresponds to one quarter-turn
    of the image (the code tracks it as ``rot0to1``) -- confirm against the
    references below.

    Reference for perm steerer: https://colab.research.google.com/drive/1ZFifMqUAOQhky1197-WAquEV1K-LhDYP?usp=sharing
    Reference for learned steerer: https://colab.research.google.com/drive/1sCqgi3yo3OuxA8VX_jPUt5ImHDmEajsZ?usp=sharing
    """

    # steer_permutations[k] is the channel order obtained by viewing the 64
    # descriptor channels as 4 groups of 16 and cyclically rolling the groups
    # by k positions (k == 0 is the identity).
    steer_permutations = [
        torch.arange(64).reshape(4, 16).roll(k, dims=0).reshape(64)
        for k in range(4)
    ]

    # XFeat backbone weights used together with the permutation steerer.
    perm_weights_gdrive_id = "1nzYg4dmkOAZPi4sjOGpQnawMoZSXYXHt"
    perm_weights_path = WEIGHTS_DIR.joinpath("xfeat_perm_steer.pth")

    # XFeat backbone weights used together with the learned steerer.
    learned_weights_gdrive_id = "1yJtmRhPVrpbXyN7Be32-FYctmX2Oz77r"
    learned_weights_path = WEIGHTS_DIR.joinpath("xfeat_learn_steer.pth")

    # Weights of the learned steerer itself (loaded into a bias-free Linear layer).
    steerer_weights_drive_id = "1Qh_5YMjK1ZIBFVFvZlTe_eyjNPrOQ2Dv"
    steerer_weights_path = WEIGHTS_DIR.joinpath("xfeat_learn_steer_steerer.pth")

    def __init__(self, device="cpu", max_num_keypoints=4096, mode="sparse", steerer_type="learned", *args, **kwargs):
        """Build the XFeat model, fetch its weights and set up the steerer.

        Args:
            device: device the model and steerer are moved to.
            max_num_keypoints: forwarded to XFeat as ``top_k``.
            mode: ``"sparse"`` or ``"semi-dense"``.
            steerer_type: ``"learned"`` or ``"perm"``.
            *args: accepted for signature compatibility; unused here.
            **kwargs: forwarded to ``BaseMatcher.__init__``; ``min_cossim`` is
                also read from here (default 0.8 for learned, 0.9 for perm).

        Raises:
            ValueError: if ``mode`` or ``steerer_type`` is not a supported value.
        """
        super().__init__(device, **kwargs)

        if mode not in ["sparse", "semi-dense"]:
            # Report the argument itself: this class has not assigned
            # self.mode yet at this point.
            raise ValueError(f'unsupported mode for xfeat: {mode}. Must choose from ["sparse", "semi-dense"]')

        self.steerer_type = steerer_type
        if self.steerer_type not in ["learned", "perm"]:
            raise ValueError(f'unsupported type for xfeat-steerer: {steerer_type}. Must choose from ["perm", "learned"]. Learned usually perofrms better.')

        # pretrained=False: the steerer-specific checkpoints are loaded below instead.
        self.model = torch.hub.load("verlab/accelerated_features", "XFeat", pretrained=False, top_k=max_num_keypoints)
        self.download_weights()

        # The checkpoint keys lack the "net." prefix that the hub model's
        # state dict is loaded with, so re-key every entry before loading.
        state_dict = torch.load(self.weights_path, map_location="cpu")
        for k in list(state_dict):
            state_dict["net." + k] = state_dict[k]
            del state_dict[k]
        self.model.load_state_dict(state_dict)
        self.model.to(device)

        if steerer_type == 'learned':
            self.steerer = torch.nn.Linear(64, 64, bias=False)
            # The stored weight carries two trailing axes that are indexed
            # away here to obtain the 2-D Linear weight.
            self.steerer.weight.data = torch.load(self.steerer_weights_path, map_location='cpu')['weight'][..., 0, 0]
            self.steerer.eval()
            self.steerer.to(device)
        else:
            # Instance-level copy of the class-level permutations on the target device.
            self.steer_permutations = [perm.to(device) for perm in self.steer_permutations]

        self.max_num_keypoints = max_num_keypoints
        self.mode = mode
        self.min_cossim = kwargs.get("min_cossim", 0.8 if steerer_type == "learned" else 0.9)

    def download_weights(self):
        """Set ``self.weights_path`` for the chosen steerer and download missing files.

        Files are fetched from Google Drive into ``WEIGHTS_DIR`` only when they
        do not exist yet. For the learned steerer both the backbone weights and
        the steerer weights are fetched.
        """
        if self.steerer_type == "perm":
            self.weights_path = self.perm_weights_path
            if not self.perm_weights_path.exists():
                download_file_from_google_drive(self.perm_weights_gdrive_id, root=WEIGHTS_DIR, filename=self.perm_weights_path.name)

        if self.steerer_type == "learned":
            self.weights_path = self.learned_weights_path
            if not self.learned_weights_path.exists():
                download_file_from_google_drive(self.learned_weights_gdrive_id, root=WEIGHTS_DIR, filename=self.learned_weights_path.name)
            if not self.steerer_weights_path.exists():
                download_file_from_google_drive(self.steerer_weights_drive_id, root=WEIGHTS_DIR, filename=self.steerer_weights_path.name)

    def preprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Run XFeat's ``parse_input`` on ``img`` and rescale uint8 input when needed.

        The division by 255 is applied only for uint8 images in semi-dense mode
        when ``self.device`` compares equal to the string ``'cuda'``.
        """
        img = self.model.parse_input(img)
        if self.device == 'cuda' and self.mode == 'semi-dense' and img.dtype == torch.uint8:
            img = img / 255
        return img

    def _forward(self, img0, img1):
        """Match ``img0`` against ``img1``, keeping the steering step with most matches.

        Returns:
            Tuple ``(mkpts0, mkpts1, keypoints0, keypoints1, descriptors0,
            descriptors1)``: matched keypoints of both images followed by the
            squeezed keypoints and descriptors of each image.
        """
        img0, img1 = self.preprocess(img0), self.preprocess(img1)

        if self.mode == "semi-dense":
            output0 = self.model.detectAndComputeDense(img0, top_k=self.max_num_keypoints)
            output1 = self.model.detectAndComputeDense(img1, top_k=self.max_num_keypoints)

            # Baseline: no steering.
            rot0to1 = 0
            idxs_list = self.model.batch_match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)
            descriptors0 = output0["descriptors"].clone()
            for r in range(1, 4):
                if self.steerer_type == "learned":
                    # Learned steering is cumulative: apply one more step to the previous result.
                    descriptors0 = torch.nn.functional.normalize(self.steerer(descriptors0), dim=-1)
                else:
                    descriptors0 = output0["descriptors"][..., self.steer_permutations[r]]

                new_idxs_list = self.model.batch_match(
                    descriptors0,
                    output1["descriptors"],
                    min_cossim=self.min_cossim,
                )
                # Only the first batch element's match count decides the winner.
                if len(new_idxs_list[0][0]) > len(idxs_list[0][0]):
                    idxs_list = new_idxs_list
                    rot0to1 = r

            # Steer image-1 descriptors by the complementary number of steps
            # (4 - rot0to1) before refinement; presumably this brings them into
            # image 0's frame -- TODO confirm. The learned steerer output is
            # deliberately left un-normalised here, as in the original code.
            if self.steerer_type == "learned":
                if rot0to1 > 0:
                    for _ in range(4 - rot0to1):
                        output1['descriptors'] = self.steerer(output1['descriptors'])
            else:
                # Index -0 == 0 selects the identity permutation when no steering won.
                output1["descriptors"] = output1["descriptors"][..., self.steer_permutations[-rot0to1]]

            matches = self.model.refine_matches(output0, output1, matches=idxs_list, batch_idx=0)
            mkpts0, mkpts1 = matches[:, :2], matches[:, 2:]

        else:
            output0 = self.model.detectAndCompute(img0, top_k=self.max_num_keypoints)[0]
            output1 = self.model.detectAndCompute(img1, top_k=self.max_num_keypoints)[0]

            # Baseline: no steering.
            idxs0, idxs1 = self.model.match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)

            # Steer a local variable rather than output0['descriptors'] so the
            # descriptors returned below are the ones XFeat produced, matching
            # the perm path and the semi-dense branch.
            steered_descriptors0 = output0["descriptors"]
            for r in range(1, 4):
                if self.steerer_type == "learned":
                    steered_descriptors0 = torch.nn.functional.normalize(self.steerer(steered_descriptors0), dim=-1)
                else:
                    steered_descriptors0 = output0["descriptors"][..., self.steer_permutations[r]]

                new_idxs0, new_idxs1 = self.model.match(
                    steered_descriptors0,
                    output1["descriptors"],
                    min_cossim=self.min_cossim,
                )
                if len(new_idxs0) > len(idxs0):
                    idxs0, idxs1 = new_idxs0, new_idxs1

            mkpts0, mkpts1 = output0["keypoints"][idxs0], output1["keypoints"][idxs1]

        return (
            mkpts0,
            mkpts1,
            output0["keypoints"].squeeze(),
            output1["keypoints"].squeeze(),
            output0["descriptors"].squeeze(),
            output1["descriptors"].squeeze(),
        )
|
|