import warnings
from pathlib import Path
from typing import Optional, Tuple, Union

import cv2
import numpy as np
import torch
import torchvision.transforms as tfm
from PIL import Image

from matching.utils import to_normalized_coords, to_numpy, to_px_coords


class BaseMatcher(torch.nn.Module):
    """
    This serves as a base class for all matchers. It provides a simple interface
    for its sub-classes to implement, namely each matcher must specify its own
    __init__ and _forward methods. It also provides a common image_loader and
    homography estimator
    """

    # Default RANSAC parameters (passed to OpenCV's findHomography)
    DEFAULT_RANSAC_ITERS = 2000
    DEFAULT_RANSAC_CONF = 0.95
    DEFAULT_REPROJ_THRESH = 3

    def __init__(self, device="cpu", **kwargs):
        super().__init__()
        self.device = device

        self.skip_ransac = False
        self.ransac_iters = kwargs.get("ransac_iters", BaseMatcher.DEFAULT_RANSAC_ITERS)
        self.ransac_conf = kwargs.get("ransac_conf", BaseMatcher.DEFAULT_RANSAC_CONF)
        self.ransac_reproj_thresh = kwargs.get("ransac_reproj_thresh", BaseMatcher.DEFAULT_REPROJ_THRESH)

    @property
    def name(self):
        return self.__class__.__name__

    @staticmethod
    def image_loader(path: Union[str, Path], resize: Union[int, Tuple], rot_angle: float = 0) -> torch.Tensor:
        """Deprecated alias of `load_image`."""
        warnings.warn(
            "`image_loader` is deprecated; use `load_image`. It will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
        return BaseMatcher.load_image(path, resize, rot_angle)

    @staticmethod
    def load_image(
        path: Union[str, Path], resize: Optional[Union[int, Tuple]] = None, rot_angle: float = 0
    ) -> torch.Tensor:
        """Load an RGB image as a (3, H, W) float tensor in [0, 1], optionally
        resized (an int `resize` means a square (resize, resize)) and rotated by
        `rot_angle` degrees counter-clockwise."""
        if isinstance(resize, int):
            resize = (resize, resize)
        img = tfm.ToTensor()(Image.open(path).convert("RGB"))
        if resize is not None:
            img = tfm.Resize(resize, antialias=True)(img)
        img = tfm.functional.rotate(img, rot_angle)
        return img
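
    # Example (illustrative path):
    #   img = BaseMatcher.load_image("assets/img0.jpg", resize=512)
    #   img.shape -> torch.Size([3, 512, 512]), values in [0, 1]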

    def rescale_coords(
        self,
        pts: Union[np.ndarray, torch.Tensor],
        h_orig: int,
        w_orig: int,
        h_new: int,
        w_new: int,
    ) -> np.ndarray:
        """Rescale kpts coordinates from one img size to another

        Args:
            pts (np.ndarray | torch.Tensor): (N,2) array of kpts
            h_orig (int): height of original img
            w_orig (int): width of original img
            h_new (int): height of new img
            w_new (int): width of new img

        Returns:
            np.ndarray: (N,2) array of kpts in original img coordinates
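
        Example (illustrative, assuming (x, y) column ordering): a keypoint at
        (320, 240) in a 640x480 resized image maps to (640, 480) in the
        1280x960 original:

            rescale_coords(np.array([[320.0, 240.0]]), 960, 1280, 480, 640)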
        """
        return to_px_coords(to_normalized_coords(pts, h_new, w_new), h_orig, w_orig)

    @staticmethod
    def find_homography(
        points1: Union[np.ndarray, torch.Tensor],
        points2: Union[np.ndarray, torch.Tensor],
        reproj_thresh: float = DEFAULT_REPROJ_THRESH,
        num_iters: int = DEFAULT_RANSAC_ITERS,
        ransac_conf: float = DEFAULT_RANSAC_CONF,
    ):
        """Estimate the homography mapping points1 onto points2 with OpenCV's MAGSAC++.

        Returns the (3, 3) homography (None if estimation fails) and a boolean inlier mask.
        """
        assert points1.shape == points2.shape
        assert points1.shape[1] == 2
        points1, points2 = to_numpy(points1), to_numpy(points2)

        # Pass maxIters/confidence by keyword: passed positionally they would land
        # in the wrong slots (cv2.findHomography takes an optional `mask` first).
        H, inliers_mask = cv2.findHomography(
            points1, points2, cv2.USAC_MAGSAC, reproj_thresh, maxIters=num_iters, confidence=ransac_conf
        )
        if H is None:  # estimation can fail, e.g. on degenerate point configurations
            return None, np.zeros(len(points1), dtype=bool)
        assert inliers_mask.shape[1] == 1
        return H, inliers_mask[:, 0].astype(bool)
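
    # Example (synthetic points): four correspondences related by a pure (5, 5)
    # translation; the recovered H should be close to that translation, with
    # every point an inlier.
    #   pts0 = np.array([[10.0, 10.0], [100.0, 12.0], [105.0, 95.0], [12.0, 90.0]])
    #   H, mask = BaseMatcher.find_homography(pts0, pts0 + 5.0)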

    def process_matches(
        self, matched_kpts0: np.ndarray, matched_kpts1: np.ndarray
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
        """Process matches into inliers and the respective Homography using RANSAC.

        Args:
            matched_kpts0 (np.ndarray): matching kpts from img0
            matched_kpts1 (np.ndarray): matching kpts from img1

        Returns:
            Tuple[np.ndarray, np.ndarray, np.ndarray]: Homography matrix from img0 to img1, inlier kpts in img0, inlier kpts in img1
        """
        if len(matched_kpts0) < 4 or self.skip_ransac:
            return None, matched_kpts0, matched_kpts1

        H, inliers_mask = self.find_homography(
            matched_kpts0,
            matched_kpts1,
            self.ransac_reproj_thresh,
            self.ransac_iters,
            self.ransac_conf,
        )
        inlier_kpts0 = matched_kpts0[inliers_mask]
        inlier_kpts1 = matched_kpts1[inliers_mask]

        return H, inlier_kpts0, inlier_kpts1

    def preprocess(self, img: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
        """Image preprocessing for each matcher. Some matchers require grayscale, normalization, etc.
        Applied to each input img independently

        Default preprocessing is none

        Args:
            img (torch.Tensor): input image (before preprocessing)

        Returns:
            img, (H,W) (Tuple[torch.Tensor, Tuple[int, int]]): img after preprocessing, original image shape
        """
        _, h, w = img.shape
        orig_shape = h, w
        return img, orig_shape
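
    # Sketch of a sub-class override (illustrative only): a matcher that needs
    # grayscale input could implement, e.g.:
    #
    #   def preprocess(self, img: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
    #       _, h, w = img.shape
    #       return tfm.Grayscale()(img), (h, w)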

    @torch.inference_mode()
    def forward(self, img0: Union[torch.Tensor, str, Path], img1: Union[torch.Tensor, str, Path]) -> dict:
        """
        All sub-classes implement the following interface:

        Parameters
        ----------
        img0 : torch.tensor (C x H x W) | str | Path
        img1 : torch.tensor (C x H x W) | str | Path

        Returns
        -------
        dict with keys: ['num_inliers', 'H', 'all_kpts0', 'all_kpts1', 'all_desc0', 'all_desc1',
                         'matched_kpts0', 'matched_kpts1', 'inlier_kpts0', 'inlier_kpts1']

        num_inliers : int, number of inliers after RANSAC, i.e. len(inlier_kpts0)
        H : np.ndarray (3 x 3), the homography matrix mapping matched_kpts0 to matched_kpts1
            (None if RANSAC was skipped or fewer than 4 matches were found)
        all_kpts0 : np.ndarray (N0 x 2), all detected keypoints from img0
        all_kpts1 : np.ndarray (N1 x 2), all detected keypoints from img1
        all_desc0 : np.ndarray (N0 x D), all descriptors from img0
        all_desc1 : np.ndarray (N1 x D), all descriptors from img1
        matched_kpts0 : np.ndarray (N2 x 2), keypoints from img0 that match matched_kpts1 (pre-RANSAC)
        matched_kpts1 : np.ndarray (N2 x 2), keypoints from img1 that match matched_kpts0 (pre-RANSAC)
        inlier_kpts0 : np.ndarray (N3 x 2), filtered matched_kpts0 that fit the H model (post-RANSAC matched_kpts)
        inlier_kpts1 : np.ndarray (N3 x 2), filtered matched_kpts1 that fit the H model (post-RANSAC matched_kpts)
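
        Example
        -------
        Illustrative; `matcher` is any concrete sub-class instance and the
        image paths are hypothetical::

            result = matcher("assets/img0.jpg", "assets/img1.jpg")
            H, num_inliers = result["H"], result["num_inliers"]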
        """
        # Take as input a pair of images (not a batch)
        if isinstance(img0, (str, Path)):
            img0 = BaseMatcher.load_image(img0)
        if isinstance(img1, (str, Path)):
            img1 = BaseMatcher.load_image(img1)

        assert isinstance(img0, torch.Tensor)
        assert isinstance(img1, torch.Tensor)

        img0 = img0.to(self.device)
        img1 = img1.to(self.device)

        # self._forward() is implemented by the children modules
        matched_kpts0, matched_kpts1, all_kpts0, all_kpts1, all_desc0, all_desc1 = self._forward(img0, img1)

        matched_kpts0, matched_kpts1 = to_numpy(matched_kpts0), to_numpy(matched_kpts1)
        H, inlier_kpts0, inlier_kpts1 = self.process_matches(matched_kpts0, matched_kpts1)

        return {
            "num_inliers": len(inlier_kpts0),
            "H": H,
            "all_kpts0": to_numpy(all_kpts0),
            "all_kpts1": to_numpy(all_kpts1),
            "all_desc0": to_numpy(all_desc0),
            "all_desc1": to_numpy(all_desc1),
            "matched_kpts0": matched_kpts0,
            "matched_kpts1": matched_kpts1,
            "inlier_kpts0": inlier_kpts0,
            "inlier_kpts1": inlier_kpts1,
        }

    def extract(self, img: Union[str, Path, torch.Tensor]) -> dict:
        """Extract keypoints and descriptors from a single image (not a batch)."""
        if isinstance(img, (str, Path)):
            img = BaseMatcher.load_image(img)

        assert isinstance(img, torch.Tensor)

        img = img.to(self.device)

        # Reuse _forward by matching the image against itself; keep only the
        # img0-side outputs.
        matched_kpts0, _, all_kpts0, _, all_desc0, _ = self._forward(img, img)

        # Ensemble matchers return no unified keypoint set (all_kpts0 is None),
        # so fall back to their self-matched keypoints.
        kpts = matched_kpts0 if isinstance(self, EnsembleMatcher) else all_kpts0

        return {"all_kpts0": to_numpy(kpts), "all_desc0": to_numpy(all_desc0)}


class EnsembleMatcher(BaseMatcher):
    """Runs several matchers on the same image pair and concatenates their matches."""

    def __init__(self, matcher_names=(), device="cpu", number_of_keypoints=2048, **kwargs):
        from matching import get_matcher

        super().__init__(device, **kwargs)

        self.matchers = [
            get_matcher(name, device=device, max_num_keypoints=number_of_keypoints, **kwargs)
            for name in matcher_names
        ]

    def _forward(self, img0: torch.Tensor, img1: torch.Tensor) -> Tuple[np.ndarray, np.ndarray, None, None, None, None]:
        all_matched_kpts0, all_matched_kpts1 = [], []
        for matcher in self.matchers:
            matched_kpts0, matched_kpts1, _, _, _, _ = matcher._forward(img0, img1)
            all_matched_kpts0.append(to_numpy(matched_kpts0))
            all_matched_kpts1.append(to_numpy(matched_kpts1))
        all_matched_kpts0, all_matched_kpts1 = np.concatenate(all_matched_kpts0), np.concatenate(all_matched_kpts1)
        return all_matched_kpts0, all_matched_kpts1, None, None, None, None
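

# A minimal usage sketch. Assumptions: the matcher names below are registered in
# your install's `matching.get_matcher`, and the image paths exist; adjust both
# to your setup.
if __name__ == "__main__":
    matcher = EnsembleMatcher(
        matcher_names=["sift-lg", "superpoint-lg"],  # hypothetical names
        device="cuda" if torch.cuda.is_available() else "cpu",
    )
    result = matcher("assets/img0.jpg", "assets/img1.jpg")  # illustrative paths
    print(f"{result['num_inliers']} inliers; H found: {result['H'] is not None}")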