File size: 3,773 Bytes
0a82b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
import py3_wget
from kornia.color import rgb_to_grayscale
import shutil

from matching import BaseMatcher, THIRD_PARTY_DIR
from matching.utils import add_to_path

# Root of the vendored SiLK checkout inside the project's third-party directory.
BASE_PATH = THIRD_PARTY_DIR.joinpath("silk")
# NOTE(review): add_to_path presumably puts BASE_PATH on sys.path so that the
# `scripts.*` and `silk.*` imports below resolve against this checkout — confirm.
add_to_path(BASE_PATH)


def setup_silk():
    """Make the ``silk`` package importable from ``BASE_PATH``.

    SiLK is meant to be installed with ``silk/`` as a symlink pointing to
    ``lib/``. That symlink often does not survive the checkout (see
    https://github.com/facebookresearch/silk/issues/32), so this falls back to
    the "ugly but works" method of copying ``lib/`` to ``silk/``.

    If ``silk/`` is already a directory (or a working symlink to one), nothing
    is done. A broken symlink, or a plain file left where the symlink should be
    (e.g. a symlink checked out as a text file), is removed and replaced by
    the copy.

    Raises:
        FileNotFoundError: if ``BASE_PATH/lib`` is missing or not a directory.

    Returns:
        None
    """
    import importlib

    silk_dir = BASE_PATH.joinpath("silk")
    if silk_dir.is_dir():
        # Real directory or a working symlink: the package is already in place.
        return None

    lib_dir = BASE_PATH.joinpath("lib")
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if not lib_dir.is_dir():
        raise FileNotFoundError(f"Expected SiLK sources in directory {lib_dir}, but it was not found")

    # A broken symlink reports exists() == False but still occupies the path
    # (copytree would fail), and a symlink checked out as a plain file reports
    # exists() == True (the copy would be skipped). Clear either one first.
    if silk_dir.is_symlink() or silk_dir.exists():
        silk_dir.unlink()

    shutil.copytree(lib_dir, silk_dir)
    # The import system caches directory listings; invalidate them so the
    # retry import sees the package that was just created.
    importlib.invalidate_caches()
    return None


# Import SiLK's model factory, coordinate conversion and matcher. If any of
# these modules cannot be found (presumably because the lib/ -> silk/ symlink
# in the checkout is missing or broken), repair the checkout with
# setup_silk() and retry the exact same imports once; a second failure
# propagates to the caller.
try:
    from scripts.examples.common import get_model
    from silk.backbones.silk.silk import from_feature_coords_to_image_coords
    from silk.models.silk import matcher
except ModuleNotFoundError:
    setup_silk()
    from scripts.examples.common import get_model
    from silk.backbones.silk.silk import from_feature_coords_to_image_coords
    from silk.models.silk import matcher


class SilkMatcher(BaseMatcher):
    """Sparse keypoint matcher built on SiLK keypoints and descriptors.

    Adapted from the upstream inference example:
    https://github.com/facebookresearch/silk/blob/main/scripts/examples/silk-inference.py
    """

    # Base URL; the checkpoint file name is appended to it.
    CKPT_DOWNLOAD_SRC = "https://dl.fbaipublicfiles.com/silk/assets/models/silk/"
    # Local directory that downloaded checkpoints are written to.
    CKPT_DIR = BASE_PATH.joinpath(r"assets/models/silk")

    # Accepted values for ``matcher_post_processing``; forwarded to SiLK's ``matcher``.
    MATCHER_POSTPROCESS_OPTIONS = ["ratio-test", "mnn", "double-softmax"]

    def __init__(
        self,
        device="cpu",
        matcher_post_processing="ratio-test",
        matcher_thresh=0.8,
        **kwargs,
    ):
        """Set up the SiLK model and descriptor matcher.

        Args:
            device: device passed to ``BaseMatcher`` and to SiLK's ``get_model``.
            matcher_post_processing: one of ``MATCHER_POSTPROCESS_OPTIONS``.
            matcher_thresh: threshold forwarded to SiLK's ``matcher``.
            **kwargs: forwarded to ``BaseMatcher.__init__``.

        Raises:
            ValueError: if ``matcher_post_processing`` is not a supported option.
        """
        super().__init__(device, **kwargs)

        # Validate before touching the filesystem or downloading weights, and
        # raise instead of `assert` so the check survives `python -O`.
        if matcher_post_processing not in SilkMatcher.MATCHER_POSTPROCESS_OPTIONS:
            raise ValueError(f"Matcher postprocessing must be one of {SilkMatcher.MATCHER_POSTPROCESS_OPTIONS}")

        # parents=True: the intermediate assets/models/ directories may not exist yet.
        SilkMatcher.CKPT_DIR.mkdir(parents=True, exist_ok=True)
        self.download_weights()

        # NOTE(review): the checkpoint fetched by download_weights() is not passed
        # to get_model, so get_model presumably loads its own default checkpoint
        # path — confirm this is intended.
        self.model = get_model(device=device, default_outputs=("sparse_positions", "sparse_descriptors"))
        self.matcher = matcher(postprocessing=matcher_post_processing, threshold=matcher_thresh)

    def download_weights(self):
        """Download ``coco-rgb-aug.ckpt`` into ``CKPT_DIR`` unless it is already there."""
        ckpt_name = "coco-rgb-aug.ckpt"
        ckpt_path = SilkMatcher.CKPT_DIR.joinpath(ckpt_name)
        if not ckpt_path.exists():
            print(f"Downloading {ckpt_name}")
            py3_wget.download_file(SilkMatcher.CKPT_DOWNLOAD_SRC + ckpt_name, ckpt_path)

    def preprocess(self, img):
        """Add a batch dim to a 3-D image if needed and convert it to grayscale.

        Expects a float image (values in 0-1) with a channel dimension.
        """
        if img.ndim == 3:
            img = img.unsqueeze(0)
        return rgb_to_grayscale(img)

    @staticmethod
    def _to_numpy_array(tensor):
        """Detach *tensor*, move it to CPU and return it as a NumPy array."""
        return tensor.detach().cpu().numpy()

    def _forward(self, img0, img1):
        """Detect, describe and match keypoints between two images.

        Returns:
            ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)`` as NumPy arrays:
            matched keypoints of each image, all keypoints of each image (both
            in (x, y) order), and the descriptors of each image.
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        sparse_positions_0, sparse_descriptors_0 = self.model(img0)
        sparse_positions_1, sparse_descriptors_1 = self.model(img1)

        # Map positions from feature-map to image coordinates. Columns 0 and 1
        # are swapped below to obtain (x, y), so they are presumably (row, col);
        # column 2 is presumably a confidence score and is dropped.
        sparse_positions_0 = from_feature_coords_to_image_coords(self.model, sparse_positions_0)
        sparse_positions_1 = from_feature_coords_to_image_coords(self.model, sparse_positions_1)

        # Only the first (single) batch element is used.
        positions_0 = sparse_positions_0[0]
        positions_1 = sparse_positions_1[0]

        # Rows of `matches` pair an index into image 0's keypoints with one into image 1's.
        matches = self.matcher(sparse_descriptors_0[0], sparse_descriptors_1[0])

        # Matched points, converted to (x, y) = (col, row) order.
        mkpts0 = self._to_numpy_array(positions_0[matches[:, 0]])[:, [1, 0]]
        mkpts1 = self._to_numpy_array(positions_1[matches[:, 1]])[:, [1, 0]]

        # All detected points, same (x, y) convention.
        kpts0 = self._to_numpy_array(positions_0[:, :2])[:, [1, 0]]
        kpts1 = self._to_numpy_array(positions_1[:, :2])[:, [1, 0]]
        desc0 = self._to_numpy_array(sparse_descriptors_0[0])
        desc1 = self._to_numpy_array(sparse_descriptors_1[0])

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1