import torch
from torchvision.datasets.utils import download_file_from_google_drive
from matching import BaseMatcher, WEIGHTS_DIR


class xFeatSteerersMatcher(BaseMatcher):
"""
Reference for perm steerer: https://colab.research.google.com/drive/1ZFifMqUAOQhky1197-WAquEV1K-LhDYP?usp=sharing
Reference for learned steerer: https://colab.research.google.com/drive/1sCqgi3yo3OuxA8VX_jPUt5ImHDmEajsZ?usp=sharing
"""
steer_permutations = [
torch.arange(64).reshape(4, 16).roll(k, dims=0).reshape(64)
for k in range(4)
]
perm_weights_gdrive_id = "1nzYg4dmkOAZPi4sjOGpQnawMoZSXYXHt"
perm_weights_path = WEIGHTS_DIR.joinpath("xfeat_perm_steer.pth")
learned_weights_gdrive_id = "1yJtmRhPVrpbXyN7Be32-FYctmX2Oz77r"
learned_weights_path = WEIGHTS_DIR.joinpath("xfeat_learn_steer.pth")
steerer_weights_drive_id = "1Qh_5YMjK1ZIBFVFvZlTe_eyjNPrOQ2Dv"
steerer_weights_path = WEIGHTS_DIR.joinpath("xfeat_learn_steer_steerer.pth")

    def __init__(self, device="cpu", max_num_keypoints=4096, mode="sparse", steerer_type="learned", *args, **kwargs):
super().__init__(device, **kwargs)
if mode not in ["sparse", "semi-dense"]:
            raise ValueError(f'unsupported mode for xfeat: {mode}. Must choose from ["sparse", "semi-dense"]')
self.steerer_type = steerer_type
if self.steerer_type not in ["learned", "perm"]:
            raise ValueError(f'unsupported steerer type for xfeat-steerers: {steerer_type}. Must choose from ["perm", "learned"]. "learned" usually performs better.')
self.model = torch.hub.load("verlab/accelerated_features", "XFeat", pretrained=False, top_k=max_num_keypoints)
self.download_weights()
        # Load the XFeat backbone weights matching the selected steerer type; keys are
        # prefixed with "net." because the torch.hub XFeat wrapper stores the network
        # under its `net` attribute.
state_dict = torch.load(self.weights_path, map_location="cpu")
for k in list(state_dict):
state_dict["net." + k] = state_dict[k]
del state_dict[k]
self.model.load_state_dict(state_dict)
self.model.to(device)
if steerer_type == 'learned':
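            # The learned steerer is a 64x64 linear map; the checkpoint stores it as a
            # 1x1 convolution, so the trailing spatial dimensions are dropped.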
self.steerer = torch.nn.Linear(64, 64, bias=False)
self.steerer.weight.data = torch.load(self.steerer_weights_path, map_location='cpu')['weight'][..., 0, 0]
self.steerer.eval()
self.steerer.to(device)
else:
self.steer_permutations = [perm.to(device) for perm in self.steer_permutations]
self.max_num_keypoints = max_num_keypoints
self.mode = mode
self.min_cossim = kwargs.get("min_cossim", 0.8 if steerer_type == "learned" else 0.9)

    def download_weights(self):
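        # Fetch the checkpoint(s) for the selected steerer type from Google Drive into
        # WEIGHTS_DIR if they are not already present.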
if self.steerer_type == "perm":
self.weights_path = self.perm_weights_path
if not self.perm_weights_path.exists():
download_file_from_google_drive(self.perm_weights_gdrive_id, root=WEIGHTS_DIR, filename=self.perm_weights_path.name)
if self.steerer_type == "learned":
self.weights_path = self.learned_weights_path
if not self.learned_weights_path.exists():
download_file_from_google_drive(self.learned_weights_gdrive_id, root=WEIGHTS_DIR, filename=self.learned_weights_path.name)
if not self.steerer_weights_path.exists():
download_file_from_google_drive(self.steerer_weights_drive_id, root=WEIGHTS_DIR, filename=self.steerer_weights_path.name)

    def preprocess(self, img: torch.Tensor) -> torch.Tensor:
img = self.model.parse_input(img)
if self.device == 'cuda' and self.mode == 'semi-dense' and img.dtype == torch.uint8:
img = img / 255 # cuda error in upsample_bilinear_2d_out_frame if img is ubyte
return img

    def _forward(self, img0, img1):
img0, img1 = self.preprocess(img0), self.preprocess(img1)
if self.mode == "semi-dense":
output0 = self.model.detectAndComputeDense(img0, top_k=self.max_num_keypoints)
output1 = self.model.detectAndComputeDense(img1, top_k=self.max_num_keypoints)
rot0to1 = 0
idxs_list = self.model.batch_match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)
descriptors0 = output0["descriptors"].clone()
for r in range(1, 4):
if self.steerer_type == "learned":
descriptors0 = torch.nn.functional.normalize(self.steerer(descriptors0), dim=-1)
else:
descriptors0 = output0["descriptors"][..., self.steer_permutations[r]]
new_idxs_list = self.model.batch_match(
descriptors0,
output1["descriptors"],
min_cossim=self.min_cossim
)
if len(new_idxs_list[0][0]) > len(idxs_list[0][0]):
idxs_list = new_idxs_list
rot0to1 = r
# align to first image for refinement MLP
if self.steerer_type == "learned":
if rot0to1 > 0:
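                    # Four applications of the steerer compose to the identity, so
                    # steering image1's descriptors (4 - rot0to1) times rotates them by
                    # -rot0to1, aligning them with image0 before refinement.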
for _ in range(4 - rot0to1):
output1['descriptors'] = self.steerer(output1['descriptors']) # Adding normalization here hurts performance for some reason, probably due to the way it's done during training
else:
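                # steer_permutations[-rot0to1] is the inverse permutation, mapping
                # image1's descriptors into image0's frame.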
output1["descriptors"] = output1["descriptors"][..., self.steer_permutations[-rot0to1]]
matches = self.model.refine_matches(output0, output1, matches=idxs_list, batch_idx=0)
mkpts0, mkpts1 = matches[:, :2], matches[:, 2:]
else:
output0 = self.model.detectAndCompute(img0, top_k=self.max_num_keypoints)[0]
output1 = self.model.detectAndCompute(img1, top_k=self.max_num_keypoints)[0]
idxs0, idxs1 = self.model.match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)
rot0to1 = 0
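            # Same rotation search as in the semi-dense branch, but on sparse
            # per-keypoint descriptors.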
for r in range(1, 4):
if self.steerer_type == "learned":
output0['descriptors'] = torch.nn.functional.normalize(self.steerer(output0['descriptors']), dim=-1)
output0_steered_descriptors = output0['descriptors']
else:
output0_steered_descriptors = output0['descriptors'][..., self.steer_permutations[r]]
new_idxs0, new_idxs1 = self.model.match(
output0_steered_descriptors,
output1['descriptors'],
min_cossim=self.min_cossim
)
if len(new_idxs0) > len(idxs0):
idxs0 = new_idxs0
idxs1 = new_idxs1
rot0to1 = r
mkpts0, mkpts1 = output0["keypoints"][idxs0], output1["keypoints"][idxs1]
return (
mkpts0,
mkpts1,
output0["keypoints"].squeeze(),
output1["keypoints"].squeeze(),
output0["descriptors"].squeeze(),
output1["descriptors"].squeeze(),
)
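
# Usage sketch (illustrative only): construct the matcher and run it on an image pair.
# The public matching interface is presumably provided by the BaseMatcher wrapper;
# `_forward` is shown here only to document its return signature.
#
#     matcher = xFeatSteerersMatcher(device="cuda", steerer_type="learned")
#     mkpts0, mkpts1, kpts0, kpts1, desc0, desc1 = matcher._forward(img0, img1)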