import cv2
import torch
import numpy as np
from PIL import Image
import torchvision.transforms as tfm
import warnings
from pathlib import Path
from typing import Optional, Tuple, Union
from matching.utils import to_normalized_coords, to_px_coords, to_numpy
class BaseMatcher(torch.nn.Module):
"""
This serves as a base class for all matchers. It provides a simple interface
for its sub-classes to implement, namely each matcher must specify its own
__init__ and _forward methods. It also provides a common image_loader and
homography estimator
"""
# OpenCV default ransac params
DEFAULT_RANSAC_ITERS = 2000
DEFAULT_RANSAC_CONF = 0.95
DEFAULT_REPROJ_THRESH = 3
def __init__(self, device="cpu", **kwargs):
super().__init__()
self.device = device
self.skip_ransac = False
self.ransac_iters = kwargs.get("ransac_iters", BaseMatcher.DEFAULT_RANSAC_ITERS)
self.ransac_conf = kwargs.get("ransac_conf", BaseMatcher.DEFAULT_RANSAC_CONF)
self.ransac_reproj_thresh = kwargs.get("ransac_reproj_thresh", BaseMatcher.DEFAULT_REPROJ_THRESH)
@property
def name(self):
return self.__class__.__name__
@staticmethod
def image_loader(path: Union[str, Path], resize: Union[int, Tuple], rot_angle: float = 0) -> torch.Tensor:
        warnings.warn(
            "`image_loader` is deprecated; use `load_image` instead. It will be removed in a future release.",
            DeprecationWarning,
            stacklevel=2,
        )
return BaseMatcher.load_image(path, resize, rot_angle)
@staticmethod
    def load_image(path: Union[str, Path], resize: Optional[Union[int, Tuple]] = None, rot_angle: float = 0) -> torch.Tensor:
if isinstance(resize, int):
resize = (resize, resize)
img = tfm.ToTensor()(Image.open(path).convert("RGB"))
if resize is not None:
img = tfm.Resize(resize, antialias=True)(img)
img = tfm.functional.rotate(img, rot_angle)
return img
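    # Usage sketch (the path and sizes below are illustrative, not from this repo):
    #   img = BaseMatcher.load_image("query.jpg", resize=512)  # -> (3, 512, 512) float tensor in [0, 1]
    #   img = BaseMatcher.load_image("query.jpg", resize=(480, 640), rot_angle=90)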
def rescale_coords(
self,
pts: Union[np.ndarray, torch.Tensor],
h_orig: int,
w_orig: int,
h_new: int,
w_new: int,
) -> np.ndarray:
"""Rescale kpts coordinates from one img size to another
Args:
pts (np.ndarray | torch.Tensor): (N,2) array of kpts
h_orig (int): height of original img
w_orig (int): width of original img
h_new (int): height of new img
w_new (int): width of new img
Returns:
np.ndarray: (N,2) array of kpts in original img coordinates
"""
return to_px_coords(to_normalized_coords(pts, h_new, w_new), h_orig, w_orig)
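    # Illustrative sketch (assumes pts use the same coordinate ordering that
    # to_normalized_coords expects): map kpts detected on a 512x512 resized image
    # back to a 768x1024 (H x W) original:
    #   kpts_orig = self.rescale_coords(kpts_resized, h_orig=768, w_orig=1024, h_new=512, w_new=512)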
@staticmethod
def find_homography(
points1: Union[np.ndarray, torch.Tensor],
points2: Union[np.ndarray, torch.Tensor],
        reproj_thresh: float = DEFAULT_REPROJ_THRESH,
num_iters: int = DEFAULT_RANSAC_ITERS,
ransac_conf: float = DEFAULT_RANSAC_CONF,
):
assert points1.shape == points2.shape
assert points1.shape[1] == 2
points1, points2 = to_numpy(points1), to_numpy(points2)
        # maxIters/confidence must be passed by keyword: positionally, the 5th argument of cv2.findHomography is the output mask
        H, inliers_mask = cv2.findHomography(points1, points2, cv2.USAC_MAGSAC, reproj_thresh, maxIters=num_iters, confidence=ransac_conf)
assert inliers_mask.shape[1] == 1
inliers_mask = inliers_mask[:, 0]
return H, inliers_mask.astype(bool)
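    # Minimal sketch with synthetic correspondences (all values illustrative):
    #   pts1 = np.random.rand(100, 2) * 500
    #   pts2 = pts1 + np.array([10.0, -5.0])  # ground-truth pure translation
    #   H, inliers = BaseMatcher.find_homography(pts1, pts2)
    #   # H should be close to [[1, 0, 10], [0, 1, -5], [0, 0, 1]], with most inliers True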
def process_matches(
self, matched_kpts0: np.ndarray, matched_kpts1: np.ndarray
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""Process matches into inliers and the respective Homography using RANSAC.
Args:
matched_kpts0 (np.ndarray): matching kpts from img0
matched_kpts1 (np.ndarray): matching kpts from img1
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: Homography matrix from img0 to img1, inlier kpts in img0, inlier kpts in img1
"""
if len(matched_kpts0) < 4 or self.skip_ransac:
return None, matched_kpts0, matched_kpts1
H, inliers_mask = self.find_homography(
matched_kpts0,
matched_kpts1,
self.ransac_reproj_thresh,
self.ransac_iters,
self.ransac_conf,
)
inlier_kpts0 = matched_kpts0[inliers_mask]
inlier_kpts1 = matched_kpts1[inliers_mask]
return H, inlier_kpts0, inlier_kpts1
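    # Behavior sketch: with (N, 2) inputs and N >= 4, returns (H, inlier_kpts0, inlier_kpts1);
    # with N < 4 or self.skip_ransac=True, returns (None, matched_kpts0, matched_kpts1) unchanged.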
def preprocess(self, img: torch.Tensor) -> Tuple[torch.Tensor, Tuple[int, int]]:
"""Image preprocessing for each matcher. Some matchers require grayscale, normalization, etc.
Applied to each input img independently
Default preprocessing is none
Args:
img (torch.Tensor): input image (before preprocessing)
Returns:
img, (H,W) (Tuple[torch.Tensor, Tuple[int, int]]): img after preprocessing, original image shape
"""
_, h, w = img.shape
orig_shape = h, w
return img, orig_shape
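    # Sketch of a subclass override (hypothetical GrayscaleMatcher, not a matcher in this repo):
    #   class GrayscaleMatcher(BaseMatcher):
    #       def preprocess(self, img):
    #           _, h, w = img.shape
    #           return tfm.Grayscale()(img), (h, w)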
@torch.inference_mode()
def forward(self, img0: Union[torch.Tensor, str, Path], img1: Union[torch.Tensor, str, Path]) -> dict:
"""
All sub-classes implement the following interface:
Parameters
----------
        img0 : torch.Tensor (C x H x W) | str | Path
        img1 : torch.Tensor (C x H x W) | str | Path
Returns
-------
dict with keys: ['num_inliers', 'H', 'all_kpts0', 'all_kpts1', 'all_desc0', 'all_desc1',
'matched_kpts0', 'matched_kpts1', 'inlier_kpts0', 'inlier_kpts1']
num_inliers : int, number of inliers after RANSAC, i.e. len(inlier_kpts0)
        H : np.ndarray (3 x 3), the homography matrix to map matched_kpts0 to matched_kpts1 (None if fewer than 4 matches or RANSAC is skipped)
all_kpts0 : np.ndarray (N0 x 2), all detected keypoints from img0
all_kpts1 : np.ndarray (N1 x 2), all detected keypoints from img1
all_desc0 : np.ndarray (N0 x D), all descriptors from img0
all_desc1 : np.ndarray (N1 x D), all descriptors from img1
matched_kpts0 : np.ndarray (N2 x 2), keypoints from img0 that match matched_kpts1 (pre-RANSAC)
matched_kpts1 : np.ndarray (N2 x 2), keypoints from img1 that match matched_kpts0 (pre-RANSAC)
inlier_kpts0 : np.ndarray (N3 x 2), filtered matched_kpts0 that fit the H model (post-RANSAC matched_kpts)
inlier_kpts1 : np.ndarray (N3 x 2), filtered matched_kpts1 that fit the H model (post-RANSAC matched_kpts)
"""
# Take as input a pair of images (not a batch)
if isinstance(img0, (str, Path)):
img0 = BaseMatcher.load_image(img0)
if isinstance(img1, (str, Path)):
img1 = BaseMatcher.load_image(img1)
assert isinstance(img0, torch.Tensor)
assert isinstance(img1, torch.Tensor)
img0 = img0.to(self.device)
img1 = img1.to(self.device)
# self._forward() is implemented by the children modules
matched_kpts0, matched_kpts1, all_kpts0, all_kpts1, all_desc0, all_desc1 = self._forward(img0, img1)
matched_kpts0, matched_kpts1 = to_numpy(matched_kpts0), to_numpy(matched_kpts1)
H, inlier_kpts0, inlier_kpts1 = self.process_matches(matched_kpts0, matched_kpts1)
return {
"num_inliers": len(inlier_kpts0),
"H": H,
"all_kpts0": to_numpy(all_kpts0),
"all_kpts1": to_numpy(all_kpts1),
"all_desc0": to_numpy(all_desc0),
"all_desc1": to_numpy(all_desc1),
"matched_kpts0": matched_kpts0,
"matched_kpts1": matched_kpts1,
"inlier_kpts0": inlier_kpts0,
"inlier_kpts1": inlier_kpts1,
}
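    # Usage sketch (paths are placeholders; `matcher` is any BaseMatcher subclass instance):
    #   result = matcher("im0.jpg", "im1.jpg")
    #   print(result["num_inliers"])  # inlier count after RANSAC
    #   H = result["H"]               # (3, 3) homography, or None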
def extract(self, img: Union[str, Path, torch.Tensor]) -> dict:
        # Take as input a single image (not a batch)
if isinstance(img, (str, Path)):
img = BaseMatcher.load_image(img)
assert isinstance(img, torch.Tensor)
img = img.to(self.device)
        # Extract by matching the image with itself; EnsembleMatcher's _forward does not
        # return all_kpts0/all_desc0, so fall back to the matched keypoints in that case
        matched_kpts0, _, all_kpts0, _, all_desc0, _ = self._forward(img, img)
        kpts = matched_kpts0 if isinstance(self, EnsembleMatcher) else all_kpts0
return {"all_kpts0": to_numpy(kpts), "all_desc0": to_numpy(all_desc0)}
class EnsembleMatcher(BaseMatcher):
    def __init__(self, matcher_names=None, device="cpu", number_of_keypoints=2048, **kwargs):
        from matching import get_matcher
        super().__init__(device, **kwargs)
        matcher_names = matcher_names or []  # None default avoids a shared mutable default list
        self.matchers = [
            get_matcher(name, device=device, max_num_keypoints=number_of_keypoints, **kwargs) for name in matcher_names
        ]
def _forward(self, img0: torch.Tensor, img1: torch.Tensor) -> Tuple[np.ndarray, np.ndarray, None, None, None, None]:
all_matched_kpts0, all_matched_kpts1 = [], []
for matcher in self.matchers:
matched_kpts0, matched_kpts1, _, _, _, _ = matcher._forward(img0, img1)
all_matched_kpts0.append(to_numpy(matched_kpts0))
all_matched_kpts1.append(to_numpy(matched_kpts1))
all_matched_kpts0, all_matched_kpts1 = np.concatenate(all_matched_kpts0), np.concatenate(all_matched_kpts1)
return all_matched_kpts0, all_matched_kpts1, None, None, None, None
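

if __name__ == "__main__":
    # Minimal end-to-end sketch. The matcher names below are assumptions and must match
    # names known to matching.get_matcher in your install; the image paths are placeholders.
    ensemble = EnsembleMatcher(matcher_names=["sift-nn", "orb-nn"], device="cpu")  # assumed names
    result = ensemble("assets/im0.png", "assets/im1.png")  # placeholder paths
    print(f"num_inliers: {result['num_inliers']}")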