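"""Keypt2Subpx matcher wrapper.

Wraps an off-the-shelf detector/matcher (SuperPoint+LightGlue, ALIKED+LightGlue,
XFeat, XFeat+LightGlue, or DeDoDe) and refines its matched keypoints to
sub-pixel accuracy with Keypt2Subpx (https://github.com/KimSinjeong/keypt2subpx).
"""
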
import torch
import numpy as np
from omegaconf import OmegaConf
import torchvision.transforms as tfm
from matching import get_matcher, BaseMatcher, THIRD_PARTY_DIR
from matching.utils import to_numpy, to_tensor, load_module, add_to_path
BASE_PATH = THIRD_PARTY_DIR.joinpath("keypt2subpx")
add_to_path(BASE_PATH)
load_module("gluefactory", BASE_PATH.joinpath("submodules/glue_factory/gluefactory/__init__.py"))
# SuperPoint variant with a dense score map, provided by the keypt2subpx repo
from dataprocess.superpoint_densescore import SuperPoint
add_to_path(THIRD_PARTY_DIR.joinpath("LightGlue"))
from lightglue import LightGlue
from lightglue.utils import rbd, batch_to_device


class Keypt2SubpxMatcher(BaseMatcher):
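    """Sub-pixel refinement wrapper: run a base matcher, then refine the matched
    keypoints with the pretrained Keypt2Subpx module."""
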
detector_name2matcher_name = {
"splg": "superpoint-lg",
"aliked": "aliked-lg",
"xfeat": "xfeat",
"xfeat-lg": "xfeat-lg",
"dedode": "dedode",
}

    def __init__(self, device="cpu", detector_name: str | None = None, **kwargs):
        super().__init__(device, **kwargs)
        assert (
            detector_name in self.detector_name2matcher_name
        ), f"detector_name must be one of {list(self.detector_name2matcher_name)}, got {detector_name}"
        matcher_name = self.detector_name2matcher_name[detector_name]
        self.detector_name = detector_name
        if detector_name == "splg":
            # SuperPoint+LightGlue needs the custom dense-score pipeline below
            self.matcher = SuperPointDense(self.device)
        else:
            self.matcher = get_matcher(matcher_name, device=device, **kwargs)
        # the refiner is keyed by the bare detector name, e.g. "xfeat-lg" -> "xfeat"
        self.keypt2subpx = self.load_refiner(detector_name.split("-")[0])

def load_refiner(self, detector: str) -> torch.nn.Module:
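        """Load the pretrained Keypt2Subpx refiner for the given detector from torch.hub."""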
assert detector in ["splg", "aliked", "xfeat", "dedode"]
return (
torch.hub.load("KimSinjeong/keypt2subpx", "Keypt2Subpx", pretrained=True, detector=detector, verbose=False)
.eval()
.to(self.device)
)

    def get_match_idxs(self, mkpts: np.ndarray | torch.Tensor, kpts: np.ndarray | torch.Tensor) -> np.ndarray:
        """Recover the index of each matched keypoint in the full keypoint array.

        Matched keypoints are copies of detected keypoints, so exact coordinate
        equality is sufficient.
        """
        idxs = []
        kpts = to_numpy(kpts)
        for mkpt in to_numpy(mkpts):
            # take the first hit in case of duplicate keypoint coordinates
            idx = np.flatnonzero(np.all(kpts == mkpt, axis=1))[0].item()
            idxs.append(idx)
        return np.asarray(idxs)

    def get_scoremap(self, img=None, idx=None):
        """Return the dense detection score map for refiners that use one
        (ALIKED, SuperPoint); the XFeat and DeDoDe refiners take none."""
        assert img is not None or idx is not None, "Must provide either img or idx"
        if self.detector_name in ["xfeat", "xfeat-lg", "dedode"]:
            return None
        elif self.detector_name == "aliked":
            # https://github.com/cvg/LightGlue/blob/edb2b838efb2ecfe3f88097c5fad9887d95aedad/lightglue/aliked.py#L707
            return self.matcher.extractor.extract_dense_map(img[None, ...])[-1].squeeze(0)
        elif self.detector_name == "splg":
            return self.matcher.get_scoremap(idx)

def _forward(self, img0, img1):
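        """Match img0/img1 with the base matcher, then snap the matched
        keypoints to sub-pixel positions with Keypt2Subpx."""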
mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1 = self.matcher._forward(img0, img1)
if len(mkpts0): # only run subpx refinement if kpts are found
            matching_idxs0 = self.get_match_idxs(mkpts0, keypoints0)
            matching_idxs1 = self.get_match_idxs(mkpts1, keypoints1)
mdesc0, mdesc1 = descriptors0[matching_idxs0], descriptors1[matching_idxs1]
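            # for splg the score maps were cached during matching;
            # for aliked they are recomputed from the images here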
scores0, scores1 = self.get_scoremap(img0, 0), self.get_scoremap(img1, 1)
mkpts0, mkpts1 = self.keypt2subpx(
to_tensor(mkpts0, self.device),
to_tensor(mkpts1, self.device),
img0,
img1,
mdesc0,
mdesc1,
scores0,
scores1,
)
return mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1


class SuperPointDense(BaseMatcher):
    """SuperPoint + LightGlue pipeline that also exposes the dense score map
    needed for Keypt2Subpx refinement."""

    # gluefactory-style config for the dense-score SuperPoint extractor
modelconf = {
"name": "two_view_pipeline",
"extractor": {
"name": "superpoint_densescore",
"max_num_keypoints": 2048,
"force_num_keypoints": False,
"detection_threshold": 0.0,
"nms_radius": 3,
"remove_borders": 3,
"trainable": False,
},
"matcher": {
"name": "matchers.lightglue_wrapper",
"weights": "superpoint",
"depth_confidence": -1,
"width_confidence": -1,
"filter_threshold": 0.1,
"trainable": False,
},
"ground_truth": {"name": "matchers.depth_matcher", "th_positive": 3, "th_negative": 5, "th_epi": 5},
"allow_no_extract": True,
}

    def __init__(self, device="cpu", **kwargs):
        super().__init__(device, **kwargs)
        self.config = OmegaConf.create(self.modelconf)
        self.extractor = SuperPoint(self.config).to(self.device).eval()
        self.matcher = LightGlue(features="superpoint", depth_confidence=-1, width_confidence=-1).eval().to(self.device)
        # dense score maps from the most recent forward pass, keyed by image index
        self.scoremaps = {}

    def preprocess(self, img):
        # SuperPoint expects a grayscale image with a batch dimension
        return tfm.Grayscale()(img).unsqueeze(0)

    def get_scoremap(self, idx):
        # score maps are cached by _forward; idx selects image 0 or 1
        return self.scoremaps[idx]

def _forward(self, img0, img1):
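        """Extract SuperPoint features, cache their score maps, and match with LightGlue."""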
img0 = self.preprocess(img0)
img1 = self.preprocess(img1)
feats0 = self.extractor({"image": img0})
feats1 = self.extractor({"image": img1})
self.scoremaps[0] = feats0["keypoint_scores"]
self.scoremaps[1] = feats1["keypoint_scores"]
# requires keys ['keypoints', 'keypoint_scores', 'descriptors', 'image_size']
matches01 = self.matcher({"image0": feats0, "image1": feats1})
data = [feats0, feats1, matches01]
# remove batch dim and move to target device
feats0, feats1, matches01 = [batch_to_device(rbd(x), self.device) for x in data]
kpts0, kpts1, matches = (
feats0["keypoints"],
feats1["keypoints"],
matches01["matches"],
)
desc0 = feats0["descriptors"]
desc1 = feats1["descriptors"]
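        # matches is (M, 2): column 0 indexes kpts0, column 1 indexes kpts1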
mkpts0, mkpts1 = kpts0[matches[..., 0]], kpts1[matches[..., 1]]
return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1
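

if __name__ == "__main__":
    # Minimal usage sketch (illustrative, not part of the original module).
    # `load_image` is assumed from the BaseMatcher interface of the `matching`
    # package; the image paths below are placeholders.
    matcher = Keypt2SubpxMatcher(device="cpu", detector_name="xfeat")
    img0 = matcher.load_image("path/to/img0.png")
    img1 = matcher.load_image("path/to/img1.png")
    mkpts0, mkpts1, kpts0, kpts1, desc0, desc1 = matcher._forward(img0, img1)
    print(f"{len(mkpts0)} matches after sub-pixel refinement")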