Pawel Piwowarski
init commit
0a82b18
import torch
from torchvision.datasets.utils import download_file_from_google_drive
from matching import BaseMatcher, WEIGHTS_DIR
class xFeatSteerersMatcher(BaseMatcher):
    """
    XFeat matcher with rotation steerers (xfeat-steerers).

    Image0's descriptors are steered into each of 4 candidate rotations
    (r = 0..3, presumably r quarter-turns). Each steered set is matched against
    image1's descriptors, and the rotation that yields the most matches is kept.

    Two steerer types are supported:
      - "perm":    a fixed permutation of the 64 descriptor channels
                   (see ``steer_permutations``).
      - "learned": a learned 64x64 linear map applied r times, with
                   L2-normalisation after each application.

    Reference for perm steerer: https://colab.research.google.com/drive/1ZFifMqUAOQhky1197-WAquEV1K-LhDYP?usp=sharing
    Reference for learned steerer: https://colab.research.google.com/drive/1sCqgi3yo3OuxA8VX_jPUt5ImHDmEajsZ?usp=sharing
    """

    # Channel permutations for the "perm" steerer. Entry k rolls the four
    # 16-channel groups of a 64-dim descriptor by k groups (k=0 is identity).
    steer_permutations = [
        torch.arange(64).reshape(4, 16).roll(k, dims=0).reshape(64)
        for k in range(4)
    ]

    perm_weights_gdrive_id = "1nzYg4dmkOAZPi4sjOGpQnawMoZSXYXHt"
    perm_weights_path = WEIGHTS_DIR.joinpath("xfeat_perm_steer.pth")

    learned_weights_gdrive_id = "1yJtmRhPVrpbXyN7Be32-FYctmX2Oz77r"
    learned_weights_path = WEIGHTS_DIR.joinpath("xfeat_learn_steer.pth")

    steerer_weights_drive_id = "1Qh_5YMjK1ZIBFVFvZlTe_eyjNPrOQ2Dv"
    steerer_weights_path = WEIGHTS_DIR.joinpath("xfeat_learn_steer_steerer.pth")

    def __init__(self, device="cpu", max_num_keypoints=4096, mode="sparse", steerer_type="learned", *args, **kwargs):
        """
        Build the XFeat model, download/load the steerer weights and move everything to ``device``.

        Args:
            device: device the model (and learned steerer / permutations) are moved to.
            max_num_keypoints: ``top_k`` given to XFeat at construction and at detection time.
            mode: "sparse" or "semi-dense".
            steerer_type: "learned" or "perm".
            **kwargs: forwarded to ``BaseMatcher``. An optional ``min_cossim`` key overrides
                the matching threshold (default 0.8 for "learned", 0.9 for "perm").

        Raises:
            ValueError: if ``mode`` or ``steerer_type`` is not supported.
        """
        super().__init__(device, **kwargs)
        if mode not in ["sparse", "semi-dense"]:
            # Report the argument itself: ``self.mode`` is only assigned further
            # down, so referencing it here raised AttributeError instead.
            raise ValueError(f'unsupported mode for xfeat: {mode}. Must choose from ["sparse", "semi-dense"]')
        self.steerer_type = steerer_type
        if self.steerer_type not in ["learned", "perm"]:
            raise ValueError(f'unsupported type for xfeat-steerer: {steerer_type}. Must choose from ["perm", "learned"]. Learned usually perofrms better.')

        self.model = torch.hub.load("verlab/accelerated_features", "XFeat", pretrained=False, top_k=max_num_keypoints)

        # download_weights() also selects self.weights_path for the steerer type.
        self.download_weights()
        # Load xfeat-steerers weights. Keys are re-prefixed with "net." so they match
        # the hub XFeat module's state dict (presumably its ``net`` submodule).
        # NOTE(review): torch.load unpickles the file, which is fetched from a fixed
        # Google Drive id. Consider weights_only=True if the installed torch supports it.
        state_dict = torch.load(self.weights_path, map_location="cpu")
        for k in list(state_dict):
            state_dict["net." + k] = state_dict[k]
            del state_dict[k]
        self.model.load_state_dict(state_dict)
        self.model.to(device)

        if steerer_type == "learned":
            self.steerer = torch.nn.Linear(64, 64, bias=False)
            # The saved weight is indexed [..., 0, 0] to get a 64x64 matrix
            # (presumably it was stored as a 1x1 conv kernel).
            self.steerer.weight.data = torch.load(self.steerer_weights_path, map_location="cpu")["weight"][..., 0, 0]
            self.steerer.eval()
            self.steerer.to(device)
        else:
            # Instance attribute shadows the class-level CPU permutations.
            self.steer_permutations = [perm.to(device) for perm in self.steer_permutations]

        self.max_num_keypoints = max_num_keypoints
        self.mode = mode
        self.min_cossim = kwargs.get("min_cossim", 0.8 if steerer_type == "learned" else 0.9)

    def download_weights(self):
        """
        Set ``self.weights_path`` for the current steerer type.

        Downloads any missing checkpoint from Google Drive into ``WEIGHTS_DIR``.
        The "learned" type also needs the separate steerer-matrix checkpoint.
        """
        if self.steerer_type == "perm":
            self.weights_path = self.perm_weights_path
            if not self.perm_weights_path.exists():
                download_file_from_google_drive(self.perm_weights_gdrive_id, root=WEIGHTS_DIR, filename=self.perm_weights_path.name)
        elif self.steerer_type == "learned":
            self.weights_path = self.learned_weights_path
            if not self.learned_weights_path.exists():
                download_file_from_google_drive(self.learned_weights_gdrive_id, root=WEIGHTS_DIR, filename=self.learned_weights_path.name)
            if not self.steerer_weights_path.exists():
                download_file_from_google_drive(self.steerer_weights_drive_id, root=WEIGHTS_DIR, filename=self.steerer_weights_path.name)

    def preprocess(self, img: torch.Tensor) -> torch.Tensor:
        """Convert ``img`` with XFeat's ``parse_input``; on CUDA in semi-dense mode, turn uint8 into float in [0, 1]."""
        img = self.model.parse_input(img)
        # Normalise the device spec so "cuda:0" / torch.device("cuda") are caught
        # too, not only the literal string "cuda".
        on_cuda = torch.device(self.device).type == "cuda"
        if on_cuda and self.mode == "semi-dense" and img.dtype == torch.uint8:
            img = img / 255  # cuda error in upsample_bilinear_2d_out_frame if img is ubyte
        return img

    def _rotated_descriptors(self, descriptors):
        """
        Yield ``(r, steered)`` for r = 1, 2, 3.

        ``steered`` is ``descriptors`` steered into rotation r. ``descriptors``
        itself is never modified.
        """
        steered = descriptors
        for r in range(1, 4):
            if self.steerer_type == "learned":
                # Learned steering is cumulative: rotation r = r applications.
                steered = torch.nn.functional.normalize(self.steerer(steered), dim=-1)
            else:
                steered = descriptors[..., self.steer_permutations[r]]
            yield r, steered

    def _match_semi_dense(self, img0, img1):
        """
        Run dense detection and the rotation search, then refine the best matches.

        Returns ``(mkpts0, mkpts1, output0, output1)``. ``output0`` / ``output1``
        hold the unmodified dense detections.
        """
        output0 = self.model.detectAndComputeDense(img0, top_k=self.max_num_keypoints)
        output1 = self.model.detectAndComputeDense(img1, top_k=self.max_num_keypoints)

        rot0to1 = 0
        idxs_list = self.model.batch_match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)
        for r, descriptors0 in self._rotated_descriptors(output0["descriptors"]):
            new_idxs_list = self.model.batch_match(descriptors0, output1["descriptors"], min_cossim=self.min_cossim)
            # Compare match counts of batch item 0 (the only one refined below).
            if len(new_idxs_list[0][0]) > len(idxs_list[0][0]):
                idxs_list = new_idxs_list
                rot0to1 = r

        # Align image1's descriptors to the first image for the refinement MLP.
        # This works on a shallow copy, so the descriptors returned to the caller
        # stay in image1's own frame.
        aligned1 = dict(output1)
        if self.steerer_type == "learned":
            if rot0to1 > 0:
                for _ in range(4 - rot0to1):
                    aligned1["descriptors"] = self.steerer(aligned1["descriptors"])  # Adding normalization here hurts performance for some reason, probably due to the way it's done during training
        else:
            aligned1["descriptors"] = output1["descriptors"][..., self.steer_permutations[-rot0to1]]

        matches = self.model.refine_matches(output0, aligned1, matches=idxs_list, batch_idx=0)
        return matches[:, :2], matches[:, 2:], output0, output1

    def _match_sparse(self, img0, img1):
        """
        Run sparse detection and keep the rotation with the most matches.

        Returns ``(mkpts0, mkpts1, output0, output1)``. ``output0`` / ``output1``
        hold the unmodified sparse detections.
        """
        output0 = self.model.detectAndCompute(img0, top_k=self.max_num_keypoints)[0]
        output1 = self.model.detectAndCompute(img1, top_k=self.max_num_keypoints)[0]

        idxs0, idxs1 = self.model.match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)
        # Steered descriptors live in a local: output0["descriptors"] must stay
        # unsteered because it is returned to the caller.
        for _, descriptors0 in self._rotated_descriptors(output0["descriptors"]):
            new_idxs0, new_idxs1 = self.model.match(descriptors0, output1["descriptors"], min_cossim=self.min_cossim)
            if len(new_idxs0) > len(idxs0):
                idxs0, idxs1 = new_idxs0, new_idxs1

        mkpts0, mkpts1 = output0["keypoints"][idxs0], output1["keypoints"][idxs1]
        return mkpts0, mkpts1, output0, output1

    def _forward(self, img0, img1):
        """
        Match ``img0`` against ``img1`` using the configured ``mode`` and steerer.

        Returns:
            ``(mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1)``.
            The keypoints and descriptors are squeezed, and the descriptors are the
            raw (unsteered) ones for both images.
        """
        img0, img1 = self.preprocess(img0), self.preprocess(img1)
        if self.mode == "semi-dense":
            mkpts0, mkpts1, output0, output1 = self._match_semi_dense(img0, img1)
        else:
            mkpts0, mkpts1, output0, output1 = self._match_sparse(img0, img1)
        return (
            mkpts0,
            mkpts1,
            output0["keypoints"].squeeze(),
            output1["keypoints"].squeeze(),
            output0["descriptors"].squeeze(),
            output1["descriptors"].squeeze(),
        )