import cv2
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as tfm
import warnings
from pathlib import Path
from typing import Optional, Tuple, Union
from matching.utils import to_normalized_coords, to_px_coords, to_numpy
class BaseMatcher(torch.nn.Module):
    """
    Base class for all matchers. It provides a simple interface for its
    sub-classes to implement: each matcher must define its own __init__ and
    _forward methods. It also provides a common image loader and homography
    estimator.
    """
    # Default OpenCV RANSAC parameters
DEFAULT_RANSAC_ITERS = 2000
DEFAULT_RANSAC_CONF = 0.95
DEFAULT_REPROJ_THRESH = 3
def __init__(self, device="cpu", **kwargs):
super().__init__()
self.device = device
self.skip_ransac = False
self.ransac_iters = kwargs.get("ransac_iters", BaseMatcher.DEFAULT_RANSAC_ITERS)
self.ransac_conf = kwargs.get("ransac_conf", BaseMatcher.DEFAULT_RANSAC_CONF)
self.ransac_reproj_thresh = kwargs.get("ransac_reproj_thresh", BaseMatcher.DEFAULT_REPROJ_THRESH)
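
    # Sketch (hedged): RANSAC behavior can be tuned per instance via kwargs, e.g.
    #   matcher = SomeMatcher(device="cuda", ransac_iters=5000, ransac_conf=0.99)
    # where `SomeMatcher` is a hypothetical stand-in for any concrete subclass.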
@property
def name(self):
return self.__class__.__name__
@staticmethod
    def image_loader(path: Union[str, Path], resize: Optional[Union[int, Tuple]] = None, rot_angle: float = 0) -> torch.Tensor:
warnings.warn(
"`image_loader` is replaced by `load_image` and will be removed in a future release.",
DeprecationWarning,
)
return BaseMatcher.load_image(path, resize, rot_angle)
@staticmethod
    def load_image(path: Union[str, Path], resize: Optional[Union[int, Tuple]] = None, rot_angle: float = 0) -> torch.Tensor:
if isinstance(resize, int):
resize = (resize, resize)
img = tfm.ToTensor()(Image.open(path).convert("RGB"))
if resize is not None:
img = tfm.Resize(resize, antialias=True)(img)
img = tfm.functional.rotate(img, rot_angle)
return img
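
    # Usage sketch (the path is a hypothetical placeholder):
    #   img = BaseMatcher.load_image("path/to/img0.png", resize=512)
    # `img` is then a float tensor of shape (3, 512, 512) with values in [0, 1].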
def rescale_coords(
self,
pts: Union[np.ndarray, torch.Tensor],
h_orig: int,
w_orig: int,
h_new: int,
w_new: int,
) -> np.ndarray:
"""Rescale kpts coordinates from one img size to another
Args:
pts (np.ndarray | torch.Tensor): (N,2) array of kpts
h_orig (int): height of original img
w_orig (int): width of original img
h_new (int): height of new img
w_new (int): width of new img
Returns:
np.ndarray: (N,2) array of kpts in original img coordinates
"""
return to_px_coords(to_normalized_coords(pts, h_new, w_new), h_orig, w_orig)
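
    # Example sketch (hedged): kpts detected on a 512x512 resized copy of a
    # 1080x1920 original can be mapped back with
    #   pts_orig = self.rescale_coords(pts, 1080, 1920, 512, 512)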
@staticmethod
def find_homography(
points1: Union[np.ndarray, torch.Tensor],
points2: Union[np.ndarray, torch.Tensor],
        reproj_thresh: float = DEFAULT_REPROJ_THRESH,
num_iters: int = DEFAULT_RANSAC_ITERS,
ransac_conf: float = DEFAULT_RANSAC_CONF,
):
assert points1.shape == points2.shape
assert points1.shape[1] == 2
points1, points2 = to_numpy(points1), to_numpy(points2)
        # Pass maxIters and confidence by keyword: the fifth positional argument
        # of cv2.findHomography is `mask`, so passing them positionally misbinds them.
        H, inliers_mask = cv2.findHomography(
            points1, points2, cv2.USAC_MAGSAC, reproj_thresh, maxIters=num_iters, confidence=ransac_conf
        )
assert inliers_mask.shape[1] == 1
inliers_mask = inliers_mask[:, 0]
return H, inliers_mask.astype(bool)
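
    # Usage sketch: given two (N, 2) arrays of matched pixel coordinates,
    #   H, mask = BaseMatcher.find_homography(pts0, pts1)
    #   inliers0, inliers1 = pts0[mask], pts1[mask]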
def process_matches(
self, matched_kpts0: np.ndarray, matched_kpts1: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Process matches into inliers and the respective Homography using RANSAC.
Args:
matched_kpts0 (np.ndarray): matching kpts from img0
matched_kpts1 (np.ndarray): matching kpts from img1
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: Homography matrix from img0 to img1, inlier kpts in img0, inlier kpts in img1
"""
if len(matched_kpts0) < 4 or self.skip_ransac:
return None, matched_kpts0, matched_kpts1
H, inliers_mask = self.find_homography(
matched_kpts0,
matched_kpts1,
self.ransac_reproj_thresh,
self.ransac_iters,
self.ransac_conf,
)
inlier_kpts0 = matched_kpts0[inliers_mask]
inlier_kpts1 = matched_kpts1[inliers_mask]
return H, inlier_kpts0, inlier_kpts1
def preprocess(self, img: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""Image preprocessing for each matcher. Some matchers require grayscale, normalization, etc.
Applied to each input img independently
Default preprocessing is none
Args:
img (torch.Tensor): input image (before preprocessing)
Returns:
img, (H,W) (Tuple[torch.Tensor, Tuple[int, int]]): img after preprocessing, original image shape
"""
_, h, w = img.shape
orig_shape = h, w
return img, orig_shape
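
    # Override sketch (hedged): a matcher needing grayscale input could implement
    #   def preprocess(self, img):
    #       _, h, w = img.shape
    #       return tfm.Grayscale()(img), (h, w)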
@torch.inference_mode()
def forward(self, img0: Union[torch.Tensor, str, Path], img1: Union[torch.Tensor, str, Path]) -> dict:
"""
All sub-classes implement the following interface:
Parameters
----------
        img0 : torch.Tensor (C x H x W) | str | Path
        img1 : torch.Tensor (C x H x W) | str | Path
Returns
-------
dict with keys: ['num_inliers', 'H', 'all_kpts0', 'all_kpts1', 'all_desc0', 'all_desc1',
'matched_kpts0', 'matched_kpts1', 'inlier_kpts0', 'inlier_kpts1']
num_inliers : int, number of inliers after RANSAC, i.e. len(inlier_kpts0)
        H : np.ndarray (3 x 3), the homography matrix to map matched_kpts0 to matched_kpts1
all_kpts0 : np.ndarray (N0 x 2), all detected keypoints from img0
all_kpts1 : np.ndarray (N1 x 2), all detected keypoints from img1
all_desc0 : np.ndarray (N0 x D), all descriptors from img0
all_desc1 : np.ndarray (N1 x D), all descriptors from img1
matched_kpts0 : np.ndarray (N2 x 2), keypoints from img0 that match matched_kpts1 (pre-RANSAC)
matched_kpts1 : np.ndarray (N2 x 2), keypoints from img1 that match matched_kpts0 (pre-RANSAC)
inlier_kpts0 : np.ndarray (N3 x 2), filtered matched_kpts0 that fit the H model (post-RANSAC matched_kpts)
inlier_kpts1 : np.ndarray (N3 x 2), filtered matched_kpts1 that fit the H model (post-RANSAC matched_kpts)
"""
# Take as input a pair of images (not a batch)
if isinstance(img0, (str, Path)):
img0 = BaseMatcher.load_image(img0)
if isinstance(img1, (str, Path)):
img1 = BaseMatcher.load_image(img1)
assert isinstance(img0, torch.Tensor)
assert isinstance(img1, torch.Tensor)
img0 = img0.to(self.device)
img1 = img1.to(self.device)
# self._forward() is implemented by the children modules
matched_kpts0, matched_kpts1, all_kpts0, all_kpts1, all_desc0, all_desc1 = self._forward(img0, img1)
matched_kpts0, matched_kpts1 = to_numpy(matched_kpts0), to_numpy(matched_kpts1)
H, inlier_kpts0, inlier_kpts1 = self.process_matches(matched_kpts0, matched_kpts1)
return {
"num_inliers": len(inlier_kpts0),
"H": H,
"all_kpts0": to_numpy(all_kpts0),
"all_kpts1": to_numpy(all_kpts1),
"all_desc0": to_numpy(all_desc0),
"all_desc1": to_numpy(all_desc1),
"matched_kpts0": matched_kpts0,
"matched_kpts1": matched_kpts1,
"inlier_kpts0": inlier_kpts0,
"inlier_kpts1": inlier_kpts1,
}
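
    # Usage sketch (hedged; paths are hypothetical placeholders):
    #   result = matcher("path/to/img0.png", "path/to/img1.png")
    #   print(result["num_inliers"], result["H"])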
def extract(self, img: Union[str, Path, torch.Tensor]) -> dict:
        # Take as input a single image (not a batch)
if isinstance(img, (str, Path)):
img = BaseMatcher.load_image(img)
assert isinstance(img, torch.Tensor)
img = img.to(self.device)
        matched_kpts0, _, all_kpts0, _, all_desc0, _ = self._forward(img, img)
        # EnsembleMatcher._forward returns None for all_kpts0, so fall back to
        # the matched kpts when extracting from an ensemble.
        kpts = matched_kpts0 if isinstance(self, EnsembleMatcher) else all_kpts0
        return {"all_kpts0": to_numpy(kpts), "all_desc0": to_numpy(all_desc0)}
class EnsembleMatcher(BaseMatcher):
    def __init__(self, matcher_names=(), device="cpu", number_of_keypoints=2048, **kwargs):
        from matching import get_matcher

        super().__init__(device, **kwargs)
        self.matchers = [
            get_matcher(name, device=device, max_num_keypoints=number_of_keypoints, **kwargs)
            for name in matcher_names
        ]
def _forward(self, img0: torch.Tensor, img1: torch.Tensor) -> Tuple[np.ndarray, np.ndarray, None, None, None, None]:
all_matched_kpts0, all_matched_kpts1 = [], []
for matcher in self.matchers:
matched_kpts0, matched_kpts1, _, _, _, _ = matcher._forward(img0, img1)
all_matched_kpts0.append(to_numpy(matched_kpts0))
all_matched_kpts1.append(to_numpy(matched_kpts1))
all_matched_kpts0, all_matched_kpts1 = np.concatenate(all_matched_kpts0), np.concatenate(all_matched_kpts1)
return all_matched_kpts0, all_matched_kpts1, None, None, None, None
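

# Minimal usage sketch (hedged, not part of the library API). The matcher names
# and image paths below are hypothetical placeholders; substitute names accepted
# by `matching.get_matcher` and real image files.
if __name__ == "__main__":
    matcher = EnsembleMatcher(matcher_names=["matcher-a", "matcher-b"], device="cpu")
    result = matcher("path/to/img0.png", "path/to/img1.png")
    print(f"num_inliers={result['num_inliers']}\nH=\n{result['H']}")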