import os

import torch
import py3_wget

from matching.im_models.lightglue import SIFT, SuperPoint
from matching.utils import add_to_path
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher

add_to_path(THIRD_PARTY_DIR.joinpath("SphereGlue"))

from model.sphereglue import SphereGlue
from utils.Utils import sphericalToCartesian


def unit_cartesian(points):
    """Convert per-keypoint (phi, theta) spherical coordinates to unit-radius Cartesian vectors."""
    phi, theta = torch.split(torch.as_tensor(points), 1, dim=1)
    unitCartesian = sphericalToCartesian(phi, theta, 1).squeeze(dim=2)
    return unitCartesian


class SphereGlueBase(BaseMatcher):
    """
    This class is the parent for all methods that use SphereGlue as a matcher,
    each with a different local feature extractor. It implements the forward pass,
    which is the same regardless of the chosen feature extractor.
    Therefore this class should *NOT* be instantiated directly; it relies on its
    children to define the extractor and the matcher.
    """

    def __init__(self, device="cpu", **kwargs):
        super().__init__(device, **kwargs)
        self.sphereglue_cfg = {
            "K": kwargs.get("K", 2),
            "GNN_layers": kwargs.get("GNN_layers", ["cross"]),
            "match_threshold": kwargs.get("match_threshold", 0.2),
            "sinkhorn_iterations": kwargs.get("sinkhorn_iterations", 20),
            "aggr": kwargs.get("aggr", "add"),
            "knn": kwargs.get("knn", 20),
        }

        # Skip the RANSAC post-processing step (keypoints live on the sphere, not the image plane).
        self.skip_ransac = True

    def download_weights(self):
        if not os.path.isfile(self.model_path):
            print("Downloading SphereGlue weights")
            py3_wget.download_file(self.weights_url, self.model_path)

    def _forward(self, img0, img1):
        """
        "extractor" and "matcher" are instantiated by the subclasses.
        """
        feats0 = self.extractor.extract(img0)
        feats1 = self.extractor.extract(img1)

        # Convert keypoint coordinates to unit Cartesian vectors, as expected by SphereGlue.
        unit_cartesian1 = unit_cartesian(feats0["keypoints"][0]).unsqueeze(dim=0).to(self.device)
        unit_cartesian2 = unit_cartesian(feats1["keypoints"][0]).unsqueeze(dim=0).to(self.device)

        inputs = {
            "h1": feats0["descriptors"],
            "h2": feats1["descriptors"],
            "scores1": feats0["keypoint_scores"],
            "scores2": feats1["keypoint_scores"],
            "unitCartesian1": unit_cartesian1,
            "unitCartesian2": unit_cartesian2,
        }
        outputs = self.matcher(inputs)

        kpts0, kpts1, matches = (
            feats0["keypoints"].squeeze(dim=0),
            feats1["keypoints"].squeeze(dim=0),
            outputs["matches0"].squeeze(dim=0),
        )
        desc0 = feats0["descriptors"].squeeze(dim=0)
        desc1 = feats1["descriptors"].squeeze(dim=0)

        # Keep only keypoints with a valid match (non-negative index in matches0).
        mask = matches.ge(0)
        kpts0_idx = torch.masked_select(torch.arange(matches.shape[0]).to(mask.device), mask)
        kpts1_idx = torch.masked_select(matches, mask)
        mkpts0 = kpts0[kpts0_idx]
        mkpts1 = kpts1[kpts1_idx]

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1


class SiftSphereGlue(SphereGlueBase):
    model_path = WEIGHTS_DIR.joinpath("sift-sphereglue.pt")
    weights_url = "https://github.com/vishalsharbidar/SphereGlue/raw/refs/heads/main/model_weights/sift/autosaved.pt"

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.download_weights()
        self.sphereglue_cfg.update({
            "descriptor_dim": 128,
            "output_dim": 128 * 2,
            "max_kpts": max_num_keypoints,
        })
        self.extractor = SIFT(max_num_keypoints=max_num_keypoints).eval().to(self.device)
        self.matcher = SphereGlue(config=self.sphereglue_cfg).to(self.device)
        self.matcher.load_state_dict(torch.load(self.model_path, map_location=self.device)["MODEL_STATE_DICT"])


class SuperpointSphereGlue(SphereGlueBase):
    model_path = WEIGHTS_DIR.joinpath("superpoint-sphereglue.pt")
    weights_url = "https://github.com/vishalsharbidar/SphereGlue/raw/refs/heads/main/model_weights/superpoint/autosaved.pt"

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.download_weights()
        self.sphereglue_cfg.update({
            "descriptor_dim": 256,
            "output_dim": 256 * 2,
            "max_kpts": max_num_keypoints,
        })
        self.extractor = SuperPoint(max_num_keypoints=max_num_keypoints).eval().to(self.device)
        self.matcher = SphereGlue(config=self.sphereglue_cfg).to(self.device)
        self.matcher.load_state_dict(torch.load(self.model_path, map_location=self.device)["MODEL_STATE_DICT"])
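# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the original module).
# It assumes the BaseMatcher interface from this repo: a load_image() helper
# that reads an image into a tensor, and calling the matcher instance on two
# images to run _forward() and package its outputs. The image paths and the
# exact keys of the returned result are placeholders/assumptions.
if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    matcher = SiftSphereGlue(device=device, max_num_keypoints=2048)

    # Hypothetical equirectangular image pair; replace with real paths.
    img0 = matcher.load_image("assets/pano0.png")
    img1 = matcher.load_image("assets/pano1.png")

    result = matcher(img0, img1)  # assumed to wrap _forward() via BaseMatcher.__call__
    print(result.keys())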