import os
import torch
import py3_wget

from matching.im_models.lightglue import SIFT, SuperPoint
from matching.utils import add_to_path
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher

add_to_path(THIRD_PARTY_DIR.joinpath('SphereGlue'))

from model.sphereglue import SphereGlue
from utils.Utils import sphericalToCartesian


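def unit_cartesian(points):
    """Split (N, 2) keypoints into (phi, theta) columns and map them to unit
    Cartesian coordinates with SphereGlue's sphericalToCartesian (radius 1)."""
    phi, theta = torch.split(torch.as_tensor(points), 1, dim=1)
    unitCartesian = sphericalToCartesian(phi, theta, 1).squeeze(dim=2)
    return unitCartesian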
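

# Illustrative sketch (hypothetical helpers, not part of SphereGlue or this package).
# unit_cartesian() above splits each keypoint row into (phi, theta) and maps it
# onto the unit sphere. The pure-Python reference below shows that mapping under
# an ASSUMED convention: phi is the azimuth and theta the polar angle measured
# from +z, both in radians. SphereGlue's own convention is not confirmed here.
# _equirect_pixel_to_spherical ASSUMES keypoints given as (x, y) pixel
# coordinates on a width x height equirectangular panorama; whether keypoints
# need such a conversion before reaching unit_cartesian() is an assumption,
# not something this module does.
import math


def _equirect_pixel_to_spherical(x, y, width, height):
    """Map an equirectangular pixel (x, y) to (phi, theta) in radians (assumed convention)."""
    phi = 2.0 * math.pi * x / width   # azimuth, spans [0, 2*pi) across the image width
    theta = math.pi * y / height      # polar angle, spans [0, pi] down the image height
    return phi, theta


def _spherical_to_unit_cartesian_ref(phi, theta):
    """Return the unit vector (x, y, z) for azimuth phi and polar angle theta."""
    sin_theta = math.sin(theta)
    return (
        sin_theta * math.cos(phi),  # x
        sin_theta * math.sin(phi),  # y
        math.cos(theta),            # z
    )

# Usage: the centre pixel of a 1024 x 512 panorama lands on the equator,
# facing -x under this convention:
#   _spherical_to_unit_cartesian_ref(*_equirect_pixel_to_spherical(512, 256, 1024, 512))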


class SphereGlueBase(BaseMatcher):
    """
    This class is the parent for all methods that use LightGlue as a matcher,
    with different local features. It implements the forward which is the same
    regardless of the feature extractor of choice.
    Therefore this class should *NOT* be instatiated, as it needs its children to define
    the extractor and the matcher.
    """

    def __init__(self, device="cpu", **kwargs):
        super().__init__(device, **kwargs)
        self.sphereglue_cfg = {
            "K": kwargs.get("K", 2),
            "GNN_layers": kwargs.get("GNN_layers", ["cross"]),
            "match_threshold": kwargs.get("match_threshold", 0.2),
            "sinkhorn_iterations": kwargs.get("sinkhorn_iterations", 20),
            "aggr": kwargs.get("aggr", "add"),
            "knn": kwargs.get("knn", 20),
        }

        self.skip_ransac = True

    def download_weights(self):
        if not os.path.isfile(self.model_path):
            print("Downloading SphereGlue weights")
            py3_wget.download_file(self.weights_url, self.model_path)

    def _forward(self, img0, img1):
        """
        "extractor" and "matcher" are instantiated by the subclasses.
        """
        feats0 = self.extractor.extract(img0)
        feats1 = self.extractor.extract(img1)

        unit_cartesian1 = unit_cartesian(feats0["keypoints"][0]).unsqueeze(dim=0).to(self.device)
        unit_cartesian2 = unit_cartesian(feats1["keypoints"][0]).unsqueeze(dim=0).to(self.device)

        inputs = {
            "h1": feats0["descriptors"],
            "h2": feats1["descriptors"],
            "scores1": feats0["keypoint_scores"],
            "scores2": feats1["keypoint_scores"],
            "unitCartesian1": unit_cartesian1,
            "unitCartesian2": unit_cartesian2,
        }
        outputs = self.matcher(inputs)

        kpts0, kpts1, matches = (
            feats0["keypoints"].squeeze(dim=0),
            feats1["keypoints"].squeeze(dim=0),
            outputs["matches0"].squeeze(dim=0),
        )
        desc0 = feats0["descriptors"].squeeze(dim=0)
        desc1 = feats1["descriptors"].squeeze(dim=0)

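        # matches[i] is the index of the keypoint in image 1 matched to keypoint i
        # of image 0, or a negative value when keypoint i is unmatched.
        mask = matches.ge(0)
        kpts0_idx = torch.masked_select(torch.arange(matches.shape[0], device=mask.device), mask)
        kpts1_idx = torch.masked_select(matches, mask)
        mkpts0 = kpts0[kpts0_idx]
        mkpts1 = kpts1[kpts1_idx]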

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1
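

# Illustrative sketch (hypothetical helper, not part of this package).
# _forward() reads SphereGlue's "matches0" as one entry per keypoint of image 0:
# the index of its match in image 1, or a negative value when it is unmatched
# (hence matches.ge(0)). This pure-Python version reproduces the same masked
# selection on plain lists so the indexing is easy to follow.


def _select_matches(matches0, kpts0, kpts1):
    """Return (mkpts0, mkpts1) for every i with matches0[i] >= 0."""
    mkpts0, mkpts1 = [], []
    for i, j in enumerate(matches0):
        if j >= 0:  # same test as matches.ge(0)
            mkpts0.append(kpts0[i])  # keypoint i in image 0
            mkpts1.append(kpts1[j])  # its match j in image 1
    return mkpts0, mkpts1

# Usage: with matches0 = [1, -1, 0], keypoint 0 pairs with keypoint 1 of image 1,
# keypoint 1 is dropped, and keypoint 2 pairs with keypoint 0 of image 1.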


class SiftSphereGlue(SphereGlueBase):
    model_path = WEIGHTS_DIR.joinpath("sift-sphereglue.pt")
    weights_url = "https://github.com/vishalsharbidar/SphereGlue/raw/refs/heads/main/model_weights/sift/autosaved.pt"

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.download_weights()
        self.sphereglue_cfg.update({
            "descriptor_dim": 128,
            "output_dim": 128*2,
            "max_kpts": max_num_keypoints
        })
        self.extractor = SIFT(max_num_keypoints=max_num_keypoints).eval().to(self.device)
        self.matcher = SphereGlue(config=self.sphereglue_cfg).eval().to(self.device)
        self.matcher.load_state_dict(torch.load(self.model_path, map_location=self.device)["MODEL_STATE_DICT"])
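

# Illustrative sketch (hypothetical helper, not part of this package).
# Both subclasses start from the shared defaults set in SphereGlueBase and then
# add extractor-specific sizes: descriptor_dim (128 for SIFT, 256 for
# SuperPoint), output_dim = 2 * descriptor_dim, and max_kpts. The helper below
# gathers that pattern in one place; _build_sphereglue_cfg does not exist in
# this package and is shown only to make the config layout explicit.


def _build_sphereglue_cfg(descriptor_dim, max_num_keypoints=2048, **overrides):
    """Merge the base defaults, user overrides and extractor sizes into one dict."""
    cfg = {
        # Same defaults as SphereGlueBase.__init__; a fresh list is built per call.
        "K": overrides.get("K", 2),
        "GNN_layers": overrides.get("GNN_layers", ["cross"]),
        "match_threshold": overrides.get("match_threshold", 0.2),
        "sinkhorn_iterations": overrides.get("sinkhorn_iterations", 20),
        "aggr": overrides.get("aggr", "add"),
        "knn": overrides.get("knn", 20),
    }
    # Extractor-specific sizes, as set in SiftSphereGlue / SuperpointSphereGlue.
    cfg.update({
        "descriptor_dim": descriptor_dim,
        "output_dim": descriptor_dim * 2,
        "max_kpts": max_num_keypoints,
    })
    return cfg

# Usage: _build_sphereglue_cfg(128) matches the SIFT config, and
# _build_sphereglue_cfg(256, knn=10) a SuperPoint config with a custom knn.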


class SuperpointSphereGlue(SphereGlueBase):
    model_path = WEIGHTS_DIR.joinpath("superpoint-sphereglue.pt")
    weights_url = "https://github.com/vishalsharbidar/SphereGlue/raw/refs/heads/main/model_weights/superpoint/autosaved.pt"

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.download_weights()
        self.sphereglue_cfg.update({
            "descriptor_dim": 256,
            "output_dim": 256*2,
            "max_kpts": max_num_keypoints
        })
        self.extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(self.device)
        self.matcher = SphereGlue(config=self.sphereglue_cfg).eval().to(self.device)
        self.matcher.load_state_dict(torch.load(self.model_path, map_location=self.device)["MODEL_STATE_DICT"])
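

# Illustrative sketch (hypothetical helper, not part of this package).
# download_weights() fetches a checkpoint only when model_path is missing. With a
# plain existence check, an interrupted download can leave a partial file that
# later runs accept as complete. _download_once writes to a temporary file in the
# target directory and renames it into place only after the fetch succeeds. It
# uses urllib.request.urlretrieve by default instead of py3_wget, and accepts any
# fetch(url, dest_path) callable so it can be exercised offline.
import os
import tempfile
import urllib.request


def _download_once(url, path, fetch=urllib.request.urlretrieve):
    """Download url to path unless path already exists; return True if it downloaded."""
    if os.path.isfile(path):
        return False
    directory = os.path.dirname(os.path.abspath(path))
    os.makedirs(directory, exist_ok=True)
    # Reserve the temporary file next to the target so the rename stays on one filesystem.
    fd, tmp_path = tempfile.mkstemp(dir=directory, suffix=".part")
    os.close(fd)
    try:
        fetch(url, tmp_path)
        # os.replace is atomic on POSIX when both paths share a filesystem.
        os.replace(tmp_path, path)
    except BaseException:
        # Never leave a half-written temporary file behind.
        if os.path.exists(tmp_path):
            os.remove(tmp_path)
        raise
    return True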