File size: 5,869 Bytes
0a82b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
import torch
import numpy as np

from omegaconf import OmegaConf
import torchvision.transforms as tfm

from matching import get_matcher, BaseMatcher, THIRD_PARTY_DIR
from matching.utils import to_numpy, to_tensor, load_module, add_to_path

# Location of the keypt2subpx sources inside the third-party directory.
BASE_PATH = THIRD_PARTY_DIR.joinpath("keypt2subpx")
# NOTE(review): add_to_path presumably makes BASE_PATH importable so that the
# `dataprocess` import below resolves — confirm against matching.utils.
add_to_path(BASE_PATH)

# NOTE(review): load_module presumably loads the gluefactory package shipped as a
# keypt2subpx submodule under the name "gluefactory" — confirm against matching.utils.
load_module("gluefactory", BASE_PATH.joinpath("submodules/glue_factory/gluefactory/__init__.py"))
# Wildcard import: `SuperPoint` (used by SuperPointDense below) is not imported
# anywhere else in this file, so it presumably comes from this module.
from dataprocess.superpoint_densescore import *

# Same path setup for the LightGlue checkout before importing from it.
add_to_path(THIRD_PARTY_DIR.joinpath("LightGlue"))
from lightglue import LightGlue
from lightglue.utils import rbd, batch_to_device


class Keypt2SubpxMatcher(BaseMatcher):
    """Run a keypoint matcher, then refine its matches with Keypt2Subpx.

    The wrapped matcher is chosen from ``detector_name``; the refiner is the
    pretrained ``Keypt2Subpx`` model fetched through ``torch.hub``.
    """

    # Detector name accepted by ``__init__`` -> matcher name given to ``get_matcher``.
    # ("splg" is special-cased in ``__init__`` and uses ``SuperPointDense`` instead.)
    detector_name2matcher_name = {
        "splg": "superpoint-lg",
        "aliked": "aliked-lg",
        "xfeat": "xfeat",
        "xfeat-lg": "xfeat-lg",
        "dedode": "dedode",
    }

    def __init__(self, device="cpu", detector_name: str | None = None, **kwargs):
        """Build the underlying matcher and load the matching refiner.

        Args:
            device: device forwarded to the base class and the wrapped matcher.
            detector_name: one of the keys of ``detector_name2matcher_name``.
            **kwargs: forwarded to the base class and to ``get_matcher``.

        Raises:
            KeyError: if ``detector_name`` is not a known detector (including
                the default ``None``).
        """
        super().__init__(device, **kwargs)

        if detector_name not in self.detector_name2matcher_name:
            # Same exception type as the plain dict lookup, but with a usable message.
            raise KeyError(
                f"Unknown detector_name {detector_name!r}; "
                f"expected one of {sorted(self.detector_name2matcher_name)}"
            )
        matcher_name = self.detector_name2matcher_name[detector_name]
        self.detector_name = detector_name
        if detector_name == "splg":
            # SuperPoint variant defined in this file that keeps its score output
            # so it can be handed to the refiner.
            self.matcher = SuperPointDense(self.device)
        else:
            self.matcher = get_matcher(matcher_name, device=device, **kwargs)

        # The refiner is selected by the part before the first "-" ("xfeat-lg" -> "xfeat").
        self.keypt2subpx = self.load_refiner(detector_name.split("-")[0])

    def load_refiner(self, detector: str) -> torch.nn.Module:
        """Load the pretrained Keypt2Subpx model for ``detector`` via ``torch.hub``.

        The model is put in eval mode and moved to ``self.device``.
        """
        assert detector in ["splg", "aliked", "xfeat", "dedode"]
        return (
            torch.hub.load("KimSinjeong/keypt2subpx", "Keypt2Subpx", pretrained=True, detector=detector, verbose=False)
            .eval()
            .to(self.device)
        )

    def get_match_idxs(self, mkpts: np.ndarray | torch.Tensor, kpts: np.ndarray | torch.Tensor) -> np.ndarray:
        """Return, for each matched keypoint, its row index in ``kpts``.

        Rows are compared with exact equality. If several rows of ``kpts`` are
        identical to a matched keypoint, the first one is used.

        Args:
            mkpts: matched keypoints, one per row.
            kpts: all detected keypoints, one per row.

        Returns:
            int64 array with one index per row of ``mkpts`` (empty if ``mkpts`` is empty).

        Raises:
            ValueError: if a matched keypoint does not appear in ``kpts``.
        """
        kpts = to_numpy(kpts)

        idxs = []
        for mkpt in to_numpy(mkpts):
            hits = np.flatnonzero(np.all(kpts == mkpt, axis=1))
            if hits.size == 0:
                raise ValueError(f"Matched keypoint {mkpt} not found among the detected keypoints")
            # Duplicate coordinates would yield several hits; take the first
            # instead of failing on a multi-element ``.item()``.
            idxs.append(int(hits[0]))
        # Explicit integer dtype so an empty result is still usable as an index.
        return np.asarray(idxs, dtype=np.int64)

    def get_scoremap(self, img=None, idx=None):
        """Return the score map the refiner should use for one image, or ``None``.

        Args:
            img: image tensor; used for the "aliked" detector.
            idx: key (0 or 1) of the image in the last pair; used for "splg".

        Returns:
            ``None`` for detectors other than "aliked" and "splg"; otherwise the
            score map produced by the underlying matcher.
        """
        assert img is not None or idx is not None, "Must provide either image or idx"
        if self.detector_name == "aliked":
            # https://github.com/cvg/LightGlue/blob/edb2b838efb2ecfe3f88097c5fad9887d95aedad/lightglue/aliked.py#L707
            return self.matcher.extractor.extract_dense_map(img[None, ...])[-1].squeeze(0)
        if self.detector_name == "splg":
            return self.matcher.get_scoremap(idx)
        # "xfeat", "xfeat-lg" and "dedode": no score map is passed to the refiner.
        return None

    def _forward(self, img0, img1):
        """Match ``img0`` and ``img1`` and refine the matched keypoints.

        Returns:
            ``(mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1)``.
            All but the first two come straight from the wrapped matcher; the
            matched keypoints are replaced by the refiner output when at least
            one match exists.
        """
        mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1 = self.matcher._forward(img0, img1)
        if len(mkpts0):  # only run subpx refinement if kpts are found
            # Recover which detected keypoint each match refers to, so the
            # descriptors of the matched keypoints can be gathered.
            matching_idxs0 = self.get_match_idxs(mkpts0, keypoints0)
            matching_idxs1 = self.get_match_idxs(mkpts1, keypoints1)
            mdesc0, mdesc1 = descriptors0[matching_idxs0], descriptors1[matching_idxs1]

            scores0, scores1 = self.get_scoremap(img0, 0), self.get_scoremap(img1, 1)
            mkpts0, mkpts1 = self.keypt2subpx(
                to_tensor(mkpts0, self.device),
                to_tensor(mkpts1, self.device),
                img0,
                img1,
                mdesc0,
                mdesc1,
                scores0,
                scores1,
            )
        return mkpts0, mkpts1, keypoints0, keypoints1, descriptors0, descriptors1


class SuperPointDense(BaseMatcher):
    """SuperPoint + LightGlue matcher that keeps the extractor's score output.

    The scores of the most recent image pair are stored so that they can be
    passed on to Keypt2Subpx refinement.
    """

    # Configuration converted to an OmegaConf object and handed to ``SuperPoint``.
    modelconf = {
        "name": "two_view_pipeline",
        "extractor": {
            "name": "superpoint_densescore",
            "max_num_keypoints": 2048,
            "force_num_keypoints": False,
            "detection_threshold": 0.0,
            "nms_radius": 3,
            "remove_borders": 3,
            "trainable": False,
        },
        "matcher": {
            "name": "matchers.lightglue_wrapper",
            "weights": "superpoint",
            "depth_confidence": -1,
            "width_confidence": -1,
            "filter_threshold": 0.1,
            "trainable": False,
        },
        "ground_truth": {
            "name": "matchers.depth_matcher",
            "th_positive": 3,
            "th_negative": 5,
            "th_epi": 5,
        },
        "allow_no_extract": True,
    }

    def __init__(self, device="cpu", **kwargs):
        """Create the SuperPoint extractor and the LightGlue matcher on ``device``."""
        super().__init__(device, **kwargs)

        self.config = OmegaConf.create(self.modelconf)

        superpoint = SuperPoint(self.config)
        self.extractor = superpoint.to(self.device).eval()

        lightglue = LightGlue(features="superpoint", depth_confidence=-1, width_confidence=-1)
        self.matcher = lightglue.to(self.device)

        # Filled by ``_forward``: key 0 / 1 -> "keypoint_scores" of the first / second image.
        self.scoremaps = {}

    def preprocess(self, img):
        """Convert ``img`` to grayscale and add a leading dimension."""
        to_gray = tfm.Grayscale()
        return to_gray(img).unsqueeze(0)

    def get_scoremap(self, idx):
        """Return the scores stored for image ``idx`` by the last ``_forward`` call."""
        return self.scoremaps[idx]

    def _forward(self, img0, img1):
        """Extract, match, and return matched keypoints, all keypoints and descriptors.

        Returns:
            ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)``.
        """
        extracted = []
        for slot, image in enumerate((img0, img1)):
            feats = self.extractor({"image": self.preprocess(image)})
            # Keep the extractor's score output (stored before ``rbd`` is applied below).
            self.scoremaps[slot] = feats["keypoint_scores"]
            extracted.append(feats)
        batched0, batched1 = extracted

        # requires keys ['keypoints', 'keypoint_scores', 'descriptors', 'image_size']
        batched_matches = self.matcher({"image0": batched0, "image1": batched1})

        # remove batch dim and move to target device
        feats0, feats1, matches01 = (
            batch_to_device(rbd(item), self.device) for item in (batched0, batched1, batched_matches)
        )

        kpts0 = feats0["keypoints"]
        kpts1 = feats1["keypoints"]
        desc0 = feats0["descriptors"]
        desc1 = feats1["descriptors"]

        # Each row of ``matches`` holds (index into kpts0, index into kpts1).
        matches = matches01["matches"]
        mkpts0 = kpts0[matches[..., 0]]
        mkpts1 = kpts1[matches[..., 1]]

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1