import torch
import numpy as np

from omegaconf import OmegaConf
import torchvision.transforms as tfm

from matching import get_matcher, BaseMatcher, THIRD_PARTY_DIR
from matching.utils import to_numpy, to_tensor, load_module, add_to_path

BASE_PATH = THIRD_PARTY_DIR.joinpath("keypt2subpx")
add_to_path(BASE_PATH)

# gluefactory must be loadable before importing the dense-score SuperPoint
# below, which depends on it
load_module("gluefactory", BASE_PATH.joinpath("submodules/glue_factory/gluefactory/__init__.py"))
from dataprocess.superpoint_densescore import SuperPoint

add_to_path(THIRD_PARTY_DIR.joinpath("LightGlue"))
from lightglue import LightGlue
from lightglue.utils import rbd, batch_to_device


class Keypt2SubpxMatcher(BaseMatcher):
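    """Sub-pixel match refinement: run a base detector/matcher, then refine
    the matched keypoints with the pretrained Keypt2Subpx refiner."""
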
    detector_name2matcher_name = {
        "splg": "superpoint-lg",
        "aliked": "aliked-lg",
        "xfeat": "xfeat",
        "xfeat-lg": "xfeat-lg",
        "dedode": "dedode",
    }

    def __init__(self, device="cpu", detector_name: str | None = None, **kwargs):
        super().__init__(device, **kwargs)

        if detector_name not in self.detector_name2matcher_name:
            raise ValueError(
                f"detector_name must be one of {list(self.detector_name2matcher_name)}, got {detector_name!r}"
            )
        matcher_name = self.detector_name2matcher_name[detector_name]
        self.detector_name = detector_name
        if detector_name == "splg":
            # SuperPoint needs its score maps cached for the refiner, so use
            # the custom pipeline defined below instead of the stock matcher
            self.matcher = SuperPointDense(self.device)
        else:
            self.matcher = get_matcher(matcher_name, device=device, **kwargs)

        # e.g. "xfeat-lg" uses the "xfeat" refiner weights
        self.keypt2subpx = self.load_refiner(detector_name.split("-")[0])

    def load_refiner(self, detector: str) -> torch.nn.Module:
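        """Load the pretrained Keypt2Subpx refiner for `detector` from torch.hub."""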
        assert detector in ["splg", "aliked", "xfeat", "dedode"]
        return (
            torch.hub.load("KimSinjeong/keypt2subpx", "Keypt2Subpx", pretrained=True, detector=detector, verbose=False)
            .eval()
            .to(self.device)
        )

    def get_match_idxs(self, mkpts: np.ndarray | torch.Tensor, kpts: np.ndarray | torch.Tensor) -> np.ndarray:
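        """Return, for each matched keypoint in `mkpts`, its row index in `kpts`."""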
        idxs = []
        kpts = to_numpy(kpts)

        for mkpt in to_numpy(mkpts):
            # each matched keypoint is assumed to appear exactly once in kpts;
            # .item() raises if it is missing or duplicated
            idx = np.flatnonzero(np.all(kpts == mkpt, axis=1)).squeeze().item()
            idxs.append(idx)
        return np.asarray(idxs)

    def get_scoremap(self, img=None, idx=None):
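        """Fetch the score map the refiner expects: recomputed from `img` for
        ALIKED, read from the cache by `idx` for SuperPoint, None otherwise."""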
        assert img is not None or idx is not None, "Must provide either image or idx"
        if self.detector_name in ["xfeat", "dedode"]:
            # these refiner variants do not use a score map
            return None
        elif self.detector_name == "aliked":
            return self.matcher.extractor.extract_dense_map(img[None, ...])[-1].squeeze(0)
        elif self.detector_name == "splg":
            # score maps were cached by SuperPointDense during its forward pass
            return self.matcher.get_scoremap(idx)

    def _forward(self, img0, img1):
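        """Match with the base matcher, then refine the matched keypoints to
        sub-pixel precision; keypoints and descriptors pass through unchanged."""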
        mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1 = self.matcher._forward(img0, img1)
        if len(mkpts0):
            # look up the descriptors belonging to the matched keypoints
            matching_idxs0 = self.get_match_idxs(mkpts0, keypoints0)
            matching_idxs1 = self.get_match_idxs(mkpts1, keypoints1)
            mdesc0, mdesc1 = descriptors0[matching_idxs0], descriptors1[matching_idxs1]

            scores0, scores1 = self.get_scoremap(img0, 0), self.get_scoremap(img1, 1)
            mkpts0, mkpts1 = self.keypt2subpx(
                to_tensor(mkpts0, self.device),
                to_tensor(mkpts1, self.device),
                img0,
                img1,
                mdesc0,
                mdesc1,
                scores0,
                scores1,
            )
        return mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1


class SuperPointDense(BaseMatcher):
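    """SuperPoint (gluefactory's dense-score variant) + LightGlue pipeline
    that caches the extractor's score maps for the Keypt2Subpx refiner."""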

    modelconf = {
        "name": "two_view_pipeline",
        "extractor": {
            "name": "superpoint_densescore",
            "max_num_keypoints": 2048,
            "force_num_keypoints": False,
            "detection_threshold": 0.0,
            "nms_radius": 3,
            "remove_borders": 3,
            "trainable": False,
        },
        "matcher": {
            "name": "matchers.lightglue_wrapper",
            "weights": "superpoint",
            "depth_confidence": -1,
            "width_confidence": -1,
            "filter_threshold": 0.1,
            "trainable": False,
        },
        "ground_truth": {"name": "matchers.depth_matcher", "th_positive": 3, "th_negative": 5, "th_epi": 5},
        "allow_no_extract": True,
    }

    def __init__(self, device="cpu", **kwargs):
        super().__init__(device, **kwargs)

        self.config = OmegaConf.create(self.modelconf)
        self.extractor = SuperPoint(self.config).to(self.device).eval()
        self.matcher = LightGlue(features="superpoint", depth_confidence=-1, width_confidence=-1).to(self.device).eval()
        # score maps cached per image index (0 or 1) during _forward
        self.scoremaps = {}

    def preprocess(self, img):
        # SuperPoint expects a grayscale image with a batch dimension
        return tfm.Grayscale()(img).unsqueeze(0)

    def get_scoremap(self, idx):
        return self.scoremaps[idx]

    def _forward(self, img0, img1):
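        """Extract SuperPoint features, cache their score maps, match with
        LightGlue, and return matched and unmatched keypoints plus descriptors."""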
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        feats0 = self.extractor({"image": img0})
        feats1 = self.extractor({"image": img1})

        # cache score maps so Keypt2SubpxMatcher.get_scoremap can retrieve them
        self.scoremaps[0] = feats0["keypoint_scores"]
        self.scoremaps[1] = feats1["keypoint_scores"]

        matches01 = self.matcher({"image0": feats0, "image1": feats1})

        # remove batch dimension and move everything to the target device
        feats0, feats1, matches01 = [batch_to_device(rbd(x), self.device) for x in (feats0, feats1, matches01)]

        kpts0, kpts1, matches = (
            feats0["keypoints"],
            feats1["keypoints"],
            matches01["matches"],
        )

        desc0 = feats0["descriptors"]
        desc1 = feats1["descriptors"]

        mkpts0, mkpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1
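

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module). It assumes
    # BaseMatcher provides load_image() and a __call__ that returns a dict with
    # "num_inliers" and "mkpts0", as the other matchers in this repo do; the
    # image paths below are placeholders.
    matcher = Keypt2SubpxMatcher(device="cpu", detector_name="xfeat")
    img0 = matcher.load_image("assets/example_0.png")
    img1 = matcher.load_image("assets/example_1.png")
    result = matcher(img0, img1)
    print(f"refined matches: {len(result['mkpts0'])}, inliers: {result['num_inliers']}")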