import torch
from torchvision.datasets.utils import download_file_from_google_drive

from matching import BaseMatcher, WEIGHTS_DIR


class xFeatSteerersMatcher(BaseMatcher):
    """
    Reference for perm steerer: https://colab.research.google.com/drive/1ZFifMqUAOQhky1197-WAquEV1K-LhDYP?usp=sharing
    Reference for learned steerer: https://colab.research.google.com/drive/1sCqgi3yo3OuxA8VX_jPUt5ImHDmEajsZ?usp=sharing
    """
    steer_permutations = [
        torch.arange(64).reshape(4, 16).roll(k, dims=0).reshape(64)
        for k in range(4)
    ]

    perm_weights_gdrive_id = "1nzYg4dmkOAZPi4sjOGpQnawMoZSXYXHt"
    perm_weights_path = WEIGHTS_DIR.joinpath("xfeat_perm_steer.pth")

    learned_weights_gdrive_id = "1yJtmRhPVrpbXyN7Be32-FYctmX2Oz77r"
    learned_weights_path = WEIGHTS_DIR.joinpath("xfeat_learn_steer.pth")

    steerer_weights_drive_id = "1Qh_5YMjK1ZIBFVFvZlTe_eyjNPrOQ2Dv"
    steerer_weights_path = WEIGHTS_DIR.joinpath("xfeat_learn_steer_steerer.pth")

    def __init__(self, device="cpu", max_num_keypoints=4096, mode="sparse", steerer_type="learned", *args, **kwargs):
        super().__init__(device, **kwargs)
        if mode not in ["sparse", "semi-dense"]:
            raise ValueError(f'unsupported mode for xfeat: {mode}. Must choose from ["sparse", "semi-dense"]')

        self.steerer_type = steerer_type
        if self.steerer_type not in ["learned", "perm"]:
            raise ValueError(f'unsupported type for xfeat-steerer: {steerer_type}. Must choose from ["perm", "learned"]. Learned usually performs better.')

        self.model = torch.hub.load("verlab/accelerated_features", "XFeat", pretrained=False, top_k=max_num_keypoints)
        self.download_weights()

        # Load the steered xfeat weights (perm or learned variant); prefix keys with "net." to match the torch.hub XFeat wrapper
        state_dict = torch.load(self.weights_path, map_location="cpu")
        for k in list(state_dict):
            state_dict["net." + k] = state_dict[k]
            del state_dict[k]
        self.model.load_state_dict(state_dict)
        self.model.to(device)

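        # The learned steerer is a 64x64 linear map on descriptors that emulates a
        # 90-degree rotation of the underlying image; its weights are stored as a
        # 1x1 conv kernel, hence the [..., 0, 0] indexing below.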
        if steerer_type == 'learned':
            self.steerer = torch.nn.Linear(64, 64, bias=False)
            self.steerer.weight.data = torch.load(self.steerer_weights_path, map_location='cpu')['weight'][..., 0, 0]
            self.steerer.eval()
            self.steerer.to(device)
        else:
            self.steer_permutations = [perm.to(device) for perm in self.steer_permutations]

        self.max_num_keypoints = max_num_keypoints
        self.mode = mode
        self.min_cossim = kwargs.get("min_cossim", 0.8 if steerer_type == "learned" else 0.9)

    def download_weights(self):
        if self.steerer_type == "perm":
            self.weights_path = self.perm_weights_path
            if not self.perm_weights_path.exists():
                download_file_from_google_drive(self.perm_weights_gdrive_id, root=WEIGHTS_DIR, filename=self.perm_weights_path.name)

        if self.steerer_type == "learned":
            self.weights_path = self.learned_weights_path
            if not self.learned_weights_path.exists():
                download_file_from_google_drive(self.learned_weights_gdrive_id, root=WEIGHTS_DIR, filename=self.learned_weights_path.name)
            if not self.steerer_weights_path.exists():
                download_file_from_google_drive(self.steerer_weights_drive_id, root=WEIGHTS_DIR, filename=self.steerer_weights_path.name)

    def preprocess(self, img: torch.Tensor) -> torch.Tensor:
        img = self.model.parse_input(img)
        if self.device == 'cuda' and self.mode == 'semi-dense' and img.dtype == torch.uint8:
            img = img / 255 # cuda error in upsample_bilinear_2d_out_frame if img is ubyte
        return img

    def _forward(self, img0, img1):
        img0, img1 = self.preprocess(img0), self.preprocess(img1)

        if self.mode == "semi-dense":
            output0 = self.model.detectAndComputeDense(img0, top_k=self.max_num_keypoints)
            output1 = self.model.detectAndComputeDense(img1, top_k=self.max_num_keypoints)

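            # Steer image0's descriptors through the four 90-degree rotations and keep
            # the rotation that yields the most matches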
            rot0to1 = 0
            idxs_list = self.model.batch_match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)
            descriptors0 = output0["descriptors"].clone()
            for r in range(1, 4):
                if self.steerer_type == "learned":
                    descriptors0 = torch.nn.functional.normalize(self.steerer(descriptors0), dim=-1)
                else:
                    descriptors0 = output0["descriptors"][..., self.steer_permutations[r]]

                new_idxs_list = self.model.batch_match(
                    descriptors0,
                    output1["descriptors"],
                    min_cossim=self.min_cossim
                )
                if len(new_idxs_list[0][0]) > len(idxs_list[0][0]):
                    idxs_list = new_idxs_list
                    rot0to1 = r

            # align to first image for refinement MLP
            if self.steerer_type == "learned":
                if rot0to1 > 0:
                    for _ in range(4 - rot0to1):
                        output1['descriptors'] = self.steerer(output1['descriptors'])  # Adding normalization here hurts performance for some reason, probably due to the way it's done during training
            else:
                output1["descriptors"] = output1["descriptors"][..., self.steer_permutations[-rot0to1]]

            matches = self.model.refine_matches(output0, output1, matches=idxs_list, batch_idx=0)
            mkpts0, mkpts1 = matches[:, :2], matches[:, 2:]

        else:
            output0 = self.model.detectAndCompute(img0, top_k=self.max_num_keypoints)[0]
            output1 = self.model.detectAndCompute(img1, top_k=self.max_num_keypoints)[0]

            idxs0, idxs1 = self.model.match(output0["descriptors"], output1["descriptors"], min_cossim=self.min_cossim)
            rot0to1 = 0
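            # Same steering search as the semi-dense branch: try each 90-degree rotation
            # of image0's descriptors and keep the one with the most matches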
            for r in range(1, 4):
                if self.steerer_type == "learned":
                    output0['descriptors'] = torch.nn.functional.normalize(self.steerer(output0['descriptors']), dim=-1)
                    output0_steered_descriptors = output0['descriptors']
                else:
                    output0_steered_descriptors = output0['descriptors'][..., self.steer_permutations[r]]

                new_idxs0, new_idxs1 = self.model.match(
                    output0_steered_descriptors,
                    output1['descriptors'],
                    min_cossim=self.min_cossim
                )
                if len(new_idxs0) > len(idxs0):
                    idxs0 = new_idxs0
                    idxs1 = new_idxs1
                    rot0to1 = r

            mkpts0, mkpts1 = output0["keypoints"][idxs0], output1["keypoints"][idxs1]

        return (
            mkpts0,
            mkpts1,
            output0["keypoints"].squeeze(),
            output1["keypoints"].squeeze(),
            output0["descriptors"].squeeze(),
            output1["descriptors"].squeeze(),
        )
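

if __name__ == "__main__":
    # Minimal usage sketch (an assumption, not part of the original file): it presumes
    # BaseMatcher makes the matcher callable on a pair of image tensors and that xfeat's
    # parse_input accepts CxHxW uint8 inputs. Replace the random tensors with real
    # images before drawing any conclusions from the output.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    matcher = xFeatSteerersMatcher(device=device, max_num_keypoints=2048, mode="sparse", steerer_type="learned")

    img0 = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # placeholder image 0
    img1 = torch.randint(0, 256, (3, 480, 640), dtype=torch.uint8)  # placeholder image 1
    with torch.inference_mode():
        result = matcher(img0, img1)
    print(result)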