import cv2
import numpy as np
import torch

from matching import BaseMatcher
from matching.utils import to_numpy


class HandcraftedBaseMatcher(BaseMatcher):
    """
    This class is the parent for all methods that use a handcrafted detector/descriptor,
    It implements the forward which is the same regardless of the feature extractor of choice.
    Therefore this class should *NOT* be instatiated, as it needs its children to define
    the extractor/detector.
    """

    def __init__(self, device="cpu", **kwargs):
        super().__init__(device, **kwargs)

    @staticmethod
    def preprocess(im_tensor: torch.Tensor) -> np.ndarray:
        # convert a CHW tensor to a grayscale, 0-255 uint8 array for OpenCV
        im_arr = to_numpy(im_tensor).transpose(1, 2, 0)
        im = cv2.cvtColor(im_arr, cv2.COLOR_RGB2GRAY)
        im = cv2.normalize(im, None, 0, 255, cv2.NORM_MINMAX).astype("uint8")

        return im

    def _forward(self, img0, img1):
        """
        "det_descr" is instantiated by the subclasses.
        """
        # convert tensors to grayscale uint8 arrays for OpenCV
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        # find the keypoints and descriptors with the chosen detector/descriptor
        kp0, des0 = self.det_descr.detectAndCompute(img0, None)
        kp1, des1 = self.det_descr.detectAndCompute(img1, None)

        # brute-force kNN matching (the distance norm is set by the subclass)
        raw_matches = self.bf.knnMatch(des0, des1, k=self.k_neighbors)

        # Lowe's ratio test: keep a match only if the best neighbor is clearly closer
        # than the second-best one. knnMatch may return fewer than k neighbors for a
        # descriptor, so skip incomplete pairs.
        good = []
        for match_pair in raw_matches:
            if len(match_pair) < 2:
                continue
            m, n = match_pair
            if m.distance < self.threshold * n.distance:
                good.append(m)

        mkpts0, mkpts1 = [], []
        for good_match in good:
            kpt_0 = np.array(kp0[good_match.queryIdx].pt)
            kpt_1 = np.array(kp1[good_match.trainIdx].pt)

            mkpts0.append(kpt_0)
            mkpts1.append(kpt_1)

        mkpts0 = np.array(mkpts0, dtype=np.float32)
        mkpts1 = np.array(mkpts1, dtype=np.float32)

        keypoints_0 = np.array([(x.pt[0], x.pt[1]) for x in kp0])
        keypoints_1 = np.array([(x.pt[0], x.pt[1]) for x in kp1])

        return mkpts0, mkpts1, keypoints_0, keypoints_1, des0, des1


class SiftNNMatcher(HandcraftedBaseMatcher):
    def __init__(self, device="cpu", max_num_keypoints=2048, lowe_thresh=0.75, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.threshold = lowe_thresh
        self.det_descr = cv2.SIFT_create(max_num_keypoints)

        self.bf = cv2.BFMatcher()  # default L2 norm suits SIFT's float descriptors
        self.k_neighbors = 2


class OrbNNMatcher(HandcraftedBaseMatcher):
    def __init__(self, device="cpu", max_num_keypoints=2048, lowe_thresh=0.75, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.threshold = lowe_thresh
        self.det_descr = cv2.ORB_create(max_num_keypoints)
        self.bf = cv2.BFMatcher(cv2.NORM_HAMMING)  # Hamming norm for ORB's binary descriptors
        self.k_neighbors = 2
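

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the original module). It assumes
    # that `BaseMatcher` can be built with just a device string and that `_forward` may be
    # called directly on 3xHxW float tensors in [0, 1]; adapt to the real pipeline as needed.
    rng = np.random.default_rng(0)
    base = np.zeros((240, 320), dtype=np.uint8)
    for _ in range(50):
        center = (int(rng.integers(20, 300)), int(rng.integers(20, 220)))
        cv2.circle(base, center, int(rng.integers(3, 10)), int(rng.integers(80, 255)), -1)
    shifted = np.roll(base, shift=10, axis=1)  # same texture, translated by 10 px

    def to_tensor(im: np.ndarray) -> torch.Tensor:
        # replicate the grayscale image into 3 channels, since preprocess() expects RGB
        return torch.from_numpy(np.stack([im] * 3)).float() / 255.0

    matcher = SiftNNMatcher(max_num_keypoints=1024, lowe_thresh=0.75)
    mkpts0, mkpts1, kpts0, kpts1, _, _ = matcher._forward(to_tensor(base), to_tensor(shifted))
    print(f"SIFT+NN: {len(mkpts0)} ratio-test matches from {len(kpts0)} detected keypoints")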