# Pawel Piwowarski - init commit (0a82b18)
import os
import torch
import py3_wget
from matching.im_models.lightglue import SIFT, SuperPoint
from matching.utils import add_to_path
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
add_to_path(THIRD_PARTY_DIR.joinpath('SphereGlue'))
from model.sphereglue import SphereGlue
from utils.Utils import sphericalToCartesian
def unit_cartesian(points):
    """Convert 2-column spherical coordinates to points on the unit sphere.

    ``points`` is converted to a tensor and split column-wise (``dim=1``) into
    two single-column tensors, which are passed to ``sphericalToCartesian`` as
    ``phi`` and ``theta`` with a radius of 1. Dimension 2 of the result is then
    squeezed away.
    """
    as_tensor = torch.as_tensor(points)
    # Width-1 chunks along dim=1; unpacking requires exactly two columns.
    phi, theta = as_tensor.split(1, dim=1)
    return sphericalToCartesian(phi, theta, 1).squeeze(dim=2)
class SphereGlueBase(BaseMatcher):
    """
    Common parent for all matchers that use SphereGlue as the matcher on top of
    different local features. The forward pass does not depend on the chosen
    feature extractor, so it is implemented once here.

    This class should *NOT* be instantiated directly: its children must define
    ``self.extractor`` and ``self.matcher``, as well as the ``model_path`` and
    ``weights_url`` attributes read by ``download_weights``.
    """

    def __init__(self, device="cpu", **kwargs):
        """Build the SphereGlue config from ``kwargs``, falling back to defaults."""
        super().__init__(device, **kwargs)

        # Flag inherited from the BaseMatcher contract.
        # NOTE(review): presumably tells the base class not to run RANSAC on the
        # returned matches - confirm against BaseMatcher.
        self.skip_ransac = True

        # Each option may be overridden through a keyword argument of the same name.
        defaults = {
            "K": 2,
            "GNN_layers": ["cross"],
            "match_threshold": 0.2,
            "sinkhorn_iterations": 20,
            "aggr": "add",
            "knn": 20,
        }
        self.sphereglue_cfg = {
            option: kwargs.get(option, fallback) for option, fallback in defaults.items()
        }

    def download_weights(self):
        """Fetch the checkpoint from ``weights_url`` unless ``model_path`` already exists."""
        if os.path.isfile(self.model_path):
            return
        print("Downloading SphereGlue weights")
        py3_wget.download_file(self.weights_url, self.model_path)

    def _forward(self, img0, img1):
        """
        Extract features from both images, match them with SphereGlue and return
        ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)``.

        "extractor" and "matcher" are instantiated by the subclasses.
        """
        feats0 = self.extractor.extract(img0)
        feats1 = self.extractor.extract(img1)

        # The first batch entry of each keypoint set is read as (phi, theta)
        # columns and projected onto the unit sphere, then re-batched.
        # NOTE(review): keypoints come straight from the extractor - confirm they
        # are in the angular convention sphericalToCartesian expects.
        sphere_pts0 = unit_cartesian(feats0["keypoints"][0]).unsqueeze(dim=0).to(self.device)
        sphere_pts1 = unit_cartesian(feats1["keypoints"][0]).unsqueeze(dim=0).to(self.device)

        outputs = self.matcher(
            {
                "h1": feats0["descriptors"],
                "h2": feats1["descriptors"],
                "scores1": feats0["keypoint_scores"],
                "scores2": feats1["keypoint_scores"],
                "unitCartesian1": sphere_pts0,
                "unitCartesian2": sphere_pts1,
            }
        )

        # Drop the leading batch dimension from everything that is returned.
        kpts0 = feats0["keypoints"].squeeze(dim=0)
        kpts1 = feats1["keypoints"].squeeze(dim=0)
        desc0 = feats0["descriptors"].squeeze(dim=0)
        desc1 = feats1["descriptors"].squeeze(dim=0)
        matches = outputs["matches0"].squeeze(dim=0)

        # Entries >= 0 are treated as valid: position i in ``matches`` holds the
        # index of the keypoint in image 1 paired with keypoint i of image 0.
        is_matched = matches >= 0
        positions = torch.arange(matches.shape[0], device=is_matched.device)
        idx0 = positions.masked_select(is_matched)
        idx1 = matches[is_matched]

        return kpts0[idx0], kpts1[idx1], kpts0, kpts1, desc0, desc1
class SiftSphereGlue(SphereGlueBase):
    """SphereGlue matcher operating on SIFT features (descriptor_dim=128)."""

    # Local checkpoint location and the URL it is downloaded from when missing.
    model_path = WEIGHTS_DIR.joinpath("sift-sphereglue.pt")
    weights_url = "https://github.com/vishalsharbidar/SphereGlue/raw/refs/heads/main/model_weights/sift/autosaved.pt"

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        """
        Args:
            device: device the extractor and matcher are moved to.
            max_num_keypoints: forwarded to the SIFT extractor and stored in the
                SphereGlue config as ``max_kpts``.
            *args: accepted for signature compatibility and ignored.
            **kwargs: forwarded to ``SphereGlueBase`` (SphereGlue config overrides).
        """
        super().__init__(device, **kwargs)
        self.download_weights()

        # Extractor-specific settings on top of the shared config.
        self.sphereglue_cfg.update(
            {
                "descriptor_dim": 128,
                "output_dim": 128 * 2,
                "max_kpts": max_num_keypoints,
            }
        )

        self.extractor = SIFT(max_num_keypoints=max_num_keypoints).eval().to(self.device)

        self.matcher = SphereGlue(config=self.sphereglue_cfg).to(self.device)
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.matcher.load_state_dict(checkpoint["MODEL_STATE_DICT"])
        # Inference-only wrapper: put the matcher in eval mode, like the
        # extractor, so any dropout / normalisation layers behave
        # deterministically.
        self.matcher.eval()
class SuperpointSphereGlue(SphereGlueBase):
    """SphereGlue matcher operating on SuperPoint features (descriptor_dim=256)."""

    # Local checkpoint location and the URL it is downloaded from when missing.
    model_path = WEIGHTS_DIR.joinpath("superpoint-sphereglue.pt")
    weights_url = "https://github.com/vishalsharbidar/SphereGlue/raw/refs/heads/main/model_weights/superpoint/autosaved.pt"

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        """
        Args:
            device: device the extractor and matcher are moved to.
            max_num_keypoints: forwarded to the SuperPoint extractor and stored in
                the SphereGlue config as ``max_kpts``.
            *args: accepted for signature compatibility and ignored.
            **kwargs: forwarded to ``SphereGlueBase`` (SphereGlue config overrides).
        """
        super().__init__(device, **kwargs)
        self.download_weights()

        # Extractor-specific settings on top of the shared config.
        self.sphereglue_cfg.update(
            {
                "descriptor_dim": 256,
                "output_dim": 256 * 2,
                "max_kpts": max_num_keypoints,
            }
        )

        self.extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(self.device)

        self.matcher = SphereGlue(config=self.sphereglue_cfg).to(self.device)
        checkpoint = torch.load(self.model_path, map_location=self.device)
        self.matcher.load_state_dict(checkpoint["MODEL_STATE_DICT"])
        # Inference-only wrapper: put the matcher in eval mode, like the
        # extractor, so any dropout / normalisation layers behave
        # deterministically.
        self.matcher.eval()