File size: 5,869 Bytes
0a82b18 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 |
import torch
import numpy as np
from omegaconf import OmegaConf
import torchvision.transforms as tfm
from matching import get_matcher, BaseMatcher, THIRD_PARTY_DIR
from matching.utils import to_numpy, to_tensor, load_module, add_to_path
# Location of the keypt2subpx directory under THIRD_PARTY_DIR.
BASE_PATH = THIRD_PARTY_DIR.joinpath("keypt2subpx")
# NOTE(review): presumably makes BASE_PATH importable so that the `dataprocess`
# import below resolves -- confirm against matching.utils.add_to_path.
add_to_path(BASE_PATH)
# Load the gluefactory package from the keypt2subpx submodule tree under the
# module name "gluefactory".
# NOTE(review): presumably required by dataprocess.superpoint_densescore -- confirm.
load_module("gluefactory", BASE_PATH.joinpath("submodules/glue_factory/gluefactory/__init__.py"))
# Star import: `SuperPoint` (used by SuperPointDense below) is not imported by
# name in the visible file, so it presumably comes from here -- TODO confirm.
from dataprocess.superpoint_densescore import *
# The LightGlue directory is registered before the `lightglue` imports that follow;
# keep this ordering.
add_to_path(THIRD_PARTY_DIR.joinpath("LightGlue"))
from lightglue import LightGlue
from lightglue.utils import rbd, batch_to_device
class Keypt2SubpxMatcher(BaseMatcher):
    """Run an underlying matcher, then refine its matched keypoints with Keypt2Subpx.

    The wrapped matcher is chosen by ``detector_name``; the refiner is loaded
    through ``torch.hub`` for the detector family (the part of the name before
    the first ``-``).
    """

    # Detector name accepted by ``__init__`` -> matcher name handed to ``get_matcher``.
    detector_name2matcher_name = {
        "splg": "superpoint-lg",
        "aliked": "aliked-lg",
        "xfeat": "xfeat",
        "xfeat-lg": "xfeat-lg",
        "dedode": "dedode",
    }

    def __init__(self, device="cpu", detector_name: str | None = None, **kwargs):
        """Build the underlying matcher and load the matching refiner.

        Args:
            device: Forwarded to ``BaseMatcher.__init__`` and to ``get_matcher``.
            detector_name: One of the keys of ``detector_name2matcher_name``.
            **kwargs: Forwarded to ``BaseMatcher.__init__`` and ``get_matcher``.

        Raises:
            KeyError: If ``detector_name`` is not a key of
                ``detector_name2matcher_name`` (this includes the default ``None``).
        """
        super().__init__(device, **kwargs)

        matcher_name = self.detector_name2matcher_name[detector_name]
        self.detector_name = detector_name

        if detector_name == "splg":
            # "splg" uses the local SuperPointDense wrapper, which also stores
            # the score output that get_scoremap() hands to the refiner.
            self.matcher = SuperPointDense(self.device)
        else:
            self.matcher = get_matcher(matcher_name, device=device, **kwargs)

        # "xfeat-lg" -> "xfeat": the refiner is selected by detector family.
        self.keypt2subpx = self.load_refiner(detector_name.split("-")[0])

    def load_refiner(self, detector: str) -> torch.nn.Module:
        """Load the Keypt2Subpx refiner for ``detector`` via ``torch.hub``.

        The returned module is switched to eval mode and moved to ``self.device``.
        ``detector`` must be one of "splg", "aliked", "xfeat" or "dedode".
        """
        assert detector in ["splg", "aliked", "xfeat", "dedode"]
        return (
            torch.hub.load("KimSinjeong/keypt2subpx", "Keypt2Subpx", pretrained=True, detector=detector, verbose=False)
            .eval()
            .to(self.device)
        )

    def get_match_idxs(self, mkpts: np.ndarray | torch.Tensor, kpts: np.ndarray | torch.Tensor) -> np.ndarray:
        """Return, for every row of ``mkpts``, the index of the equal row in ``kpts``.

        Rows are compared by exact coordinate equality. When the same
        coordinates occur several times in ``kpts`` the first occurrence is
        used (previously this case raised an opaque ``ValueError``).

        Args:
            mkpts: Matched keypoints, one coordinate row per match.
            kpts: All detected keypoints, one coordinate row per keypoint.

        Returns:
            Integer array with one index into ``kpts`` per row of ``mkpts``.

        Raises:
            ValueError: If a matched keypoint does not occur in ``kpts``.
        """
        kpts_rows = np.asarray(to_numpy(kpts)).tolist()
        mkpts_rows = np.asarray(to_numpy(mkpts)).tolist()

        # One pass over kpts instead of rescanning the full array per match.
        first_index: dict[tuple, int] = {}
        for row_idx, row in enumerate(kpts_rows):
            first_index.setdefault(tuple(row), row_idx)

        idxs = []
        for mkpt in mkpts_rows:
            try:
                idxs.append(first_index[tuple(mkpt)])
            except KeyError:
                raise ValueError(f"matched keypoint {mkpt} not found among the detected keypoints") from None

        # Explicit integer dtype so an empty result is still usable as an index.
        return np.asarray(idxs, dtype=np.int64)

    def get_scoremap(self, img=None, idx=None):
        """Return the detector score map passed to the refiner, or ``None``.

        Args:
            img: Image tensor; used by the "aliked" branch.
            idx: Image slot (0 or 1); used by the "splg" branch.

        Returns:
            ``None`` for "xfeat", "dedode" and any other detector without a
            dedicated branch (e.g. "xfeat-lg"); otherwise the score map
            obtained from the underlying matcher.
        """
        assert img is not None or idx is not None, "Must provide either image or idx"

        if self.detector_name in ["xfeat", "dedode"]:
            return None
        elif self.detector_name == "aliked":
            # https://github.com/cvg/LightGlue/blob/edb2b838efb2ecfe3f88097c5fad9887d95aedad/lightglue/aliked.py#L707
            return self.matcher.extractor.extract_dense_map(img[None, ...])[-1].squeeze(0)
        elif self.detector_name == "splg":
            return self.matcher.get_scoremap(idx)

        # No dedicated score map for the remaining detector names.
        return None

    def _forward(self, img0, img1):
        """Match ``img0`` and ``img1`` and refine the matched keypoints.

        Returns:
            ``(mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1)``.
            Keypoints and descriptors come unchanged from the underlying
            matcher; the matched keypoints are the refiner output when at
            least one match was found, otherwise the matcher output.
        """
        mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1 = self.matcher._forward(img0, img1)

        if len(mkpts0):  # only run subpx refinement if kpts are found
            # Recover which detected keypoint each match refers to, so that the
            # per-match descriptors can be gathered for the refiner.
            matching_idxs0 = self.get_match_idxs(mkpts0, keypoints0)
            matching_idxs1 = self.get_match_idxs(mkpts1, keypoints1)
            mdesc0, mdesc1 = descriptors0[matching_idxs0], descriptors1[matching_idxs1]

            scores0, scores1 = self.get_scoremap(img0, 0), self.get_scoremap(img1, 1)

            mkpts0, mkpts1 = self.keypt2subpx(
                to_tensor(mkpts0, self.device),
                to_tensor(mkpts1, self.device),
                img0,
                img1,
                mdesc0,
                mdesc1,
                scores0,
                scores1,
            )

        return mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1
class SuperPointDense(BaseMatcher):
    """SuperPoint extractor plus LightGlue matcher that keeps the score output.

    The ``keypoint_scores`` produced by the extractor for both images of the
    most recent ``_forward`` call are stored in ``self.scoremaps`` so that
    they can be handed to the Keypt2Subpx refinement afterwards.
    """

    # Configuration converted to an OmegaConf object and passed to ``SuperPoint``.
    modelconf = {
        "name": "two_view_pipeline",
        "extractor": {
            "name": "superpoint_densescore",
            "max_num_keypoints": 2048,
            "force_num_keypoints": False,
            "detection_threshold": 0.0,
            "nms_radius": 3,
            "remove_borders": 3,
            "trainable": False,
        },
        "matcher": {
            "name": "matchers.lightglue_wrapper",
            "weights": "superpoint",
            "depth_confidence": -1,
            "width_confidence": -1,
            "filter_threshold": 0.1,
            "trainable": False,
        },
        "ground_truth": {"name": "matchers.depth_matcher", "th_positive": 3, "th_negative": 5, "th_epi": 5},
        "allow_no_extract": True,
    }

    def __init__(self, device="cpu", **kwargs):
        """Create the extractor and the LightGlue matcher on ``self.device``."""
        super().__init__(device, **kwargs)

        self.config = OmegaConf.create(self.modelconf)

        superpoint = SuperPoint(self.config)
        self.extractor = superpoint.to(self.device).eval()

        lightglue = LightGlue(features="superpoint", depth_confidence=-1, width_confidence=-1)
        self.matcher = lightglue.to(self.device)

        # Image slot (0 or 1) -> "keypoint_scores" of the latest _forward call.
        self.scoremaps = {}

    def preprocess(self, img):
        """Convert ``img`` to grayscale and prepend a dimension of size one."""
        grayscale = tfm.Grayscale()(img)
        return grayscale.unsqueeze(0)

    def get_scoremap(self, idx):
        """Return the scores stored for image slot ``idx`` by the last ``_forward``."""
        return self.scoremaps[idx]

    def _forward(self, img0, img1):
        """Extract features from both images and match them with LightGlue.

        Returns:
            ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)``: the matched
            keypoints, all keypoints and all descriptors of image 0 and 1.
        """
        batch0 = self.preprocess(img0)
        batch1 = self.preprocess(img1)

        raw_feats0 = self.extractor({"image": batch0})
        raw_feats1 = self.extractor({"image": batch1})

        # Keep the scores around for get_scoremap().
        for slot, raw_feats in enumerate((raw_feats0, raw_feats1)):
            self.scoremaps[slot] = raw_feats["keypoint_scores"]

        # requires keys ['keypoints', 'keypoint_scores', 'descriptors', 'image_size']
        raw_matches = self.matcher({"image0": raw_feats0, "image1": raw_feats1})

        # remove batch dim and move to target device
        feats0 = batch_to_device(rbd(raw_feats0), self.device)
        feats1 = batch_to_device(rbd(raw_feats1), self.device)
        matches01 = batch_to_device(rbd(raw_matches), self.device)

        kpts0 = feats0["keypoints"]
        kpts1 = feats1["keypoints"]
        desc0 = feats0["descriptors"]
        desc1 = feats1["descriptors"]

        # Each row of `matches` pairs an index into kpts0 with one into kpts1.
        matches = matches01["matches"]
        mkpts0 = kpts0[matches[..., 0]]
        mkpts1 = kpts1[matches[..., 1]]

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1
|