import torch
import py3_wget
import gdown
from kornia.color import rgb_to_grayscale

from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
from matching.utils import load_module, add_to_path

BASE_PATH = THIRD_PARTY_DIR.joinpath("gim")
# gim's vendored packages live under third_party/gim; they must be on sys.path
# before the dkm import below can resolve.
add_to_path(BASE_PATH)
from dkm.models.model_zoo.DKMv3 import DKMv3


class GIM_DKM(BaseMatcher):
    weights_src = "https://drive.google.com/file/d/1gk97V4IROnR1Nprq10W9NCFUv2mxXR_-/view"

    def __init__(self, device="cpu", max_num_keypoints=5000, **kwargs):
        super().__init__(device, **kwargs)
        self.ckpt_path = WEIGHTS_DIR / "gim_dkm_100h.ckpt"
        # DKMv3 matches internally at a fixed 672x896 resolution.
        self.model = DKMv3(weights=None, h=672, w=896)
        self.max_num_keypoints = max_num_keypoints

        self.download_weights()
        self.load_weights()
        self.model = self.model.eval().to(device)

    def download_weights(self):
        if not self.ckpt_path.exists():
            print(f"Downloading {self.ckpt_path.name}")
            # fuzzy=True lets gdown extract the file id from the share URL.
            gdown.download(GIM_DKM.weights_src, output=str(self.ckpt_path), fuzzy=True)

    def load_weights(self):
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        for k in list(state_dict.keys()):
            # Strip the training-time "model." prefix so keys match DKMv3's modules.
            # Track the renamed key so the fc check below pops the right entry.
            if k.startswith("model."):
                new_k = k.replace("model.", "", 1)
                state_dict[new_k] = state_dict.pop(k)
                k = new_k
            # The encoder's classification head is unused by the matcher; drop it.
            if "encoder.net.fc" in k:
                state_dict.pop(k)
        self.model.load_state_dict(state_dict)

    def preprocess(self, img):
        # DKMv3 expects batched (B, C, H, W) input.
        if img.ndim < 4:
            img = img.unsqueeze(0)
        return img

    def _forward(self, img0, img1):
        height0, width0 = img0.shape[-2:]
        height1, width1 = img1.shape[-2:]

        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        # Dense matching, then certainty-weighted sampling of a sparse match set.
        dense_matches, dense_certainty = self.model.match(img0, img1, device=self.device)
        torch.cuda.empty_cache()
        sparse_matches, mconf = self.model.sample(dense_matches, dense_certainty, self.max_num_keypoints)
        torch.cuda.empty_cache()

        mkpts0 = sparse_matches[:, :2]
        mkpts1 = sparse_matches[:, 2:]

        # DKM reports coordinates normalized to [-1, 1]; map them back to pixels.
        mkpts0 = torch.stack(
            (width0 * (mkpts0[:, 0] + 1) / 2, height0 * (mkpts0[:, 1] + 1) / 2),
            dim=-1,
        )
        mkpts1 = torch.stack(
            (width1 * (mkpts1[:, 0] + 1) / 2, height1 * (mkpts1[:, 1] + 1) / 2),
            dim=-1,
        )

        return mkpts0, mkpts1, None, None, None, None
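

# A small standalone sketch (not used by the classes in this file) that makes the
# normalized-to-pixel mapping in GIM_DKM._forward explicit: DKM reports coordinates
# in [-1, 1], so (-1, -1) maps to the pixel origin (0, 0) and (+1, +1) maps to
# (width, height).
def _normalized_to_pixels_sketch(pts: torch.Tensor, width: int, height: int) -> torch.Tensor:
    return torch.stack(
        (width * (pts[:, 0] + 1) / 2, height * (pts[:, 1] + 1) / 2),
        dim=-1,
    )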


class GIM_LG(BaseMatcher):
    # Raw URL, so py3_wget downloads the checkpoint itself rather than GitHub's HTML page.
    weights_src = "https://github.com/xuelunshen/gim/raw/main/weights/gim_lightglue_100h.ckpt"
    superpoint_v1_weight_src = (
        "https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superpoint_v1.pth"
    )

    def __init__(self, device="cpu", max_keypoints=2048, **kwargs):
        super().__init__(device, **kwargs)

        # Load gim's vendored gluefactory under its own module name.
        load_module("gluefactory_gim", BASE_PATH.joinpath("gluefactory/__init__.py"))
        from gluefactory_gim.superpoint import SuperPoint
        from gluefactory_gim.models.matchers.lightglue import LightGlue

        self.ckpt_path = BASE_PATH / "weights" / "gim_lightglue_100h.ckpt"
        self.superpoint_v1_path = BASE_PATH / "weights" / "superpoint_v1.pth"

        self.download_weights()

        self.detector = SuperPoint(
            {
                "max_num_keypoints": max_keypoints,
                "force_num_keypoints": True,
                "detection_threshold": 0.0,
                "nms_radius": 3,
                "trainable": False,
            }
        )
        self.model = LightGlue(
            {
                "filter_threshold": 0.1,
                "flash": False,
                "checkpointed": True,
            }
        )

        self.load_weights()

    def download_weights(self):
        if not self.ckpt_path.exists():
            print(f"Downloading {self.ckpt_path.name}")
            py3_wget.download_file(GIM_LG.weights_src, self.ckpt_path)
        if not self.superpoint_v1_path.exists():
            print(f"Downloading {self.superpoint_v1_path.name}")
            py3_wget.download_file(GIM_LG.superpoint_v1_weight_src, self.superpoint_v1_path)

    def load_weights(self):
        # The gim checkpoint bundles both networks, so it is read twice:
        # first keeping only the "superpoint."-prefixed keys for the detector...
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        for k in list(state_dict.keys()):
            if k.startswith("model."):
                state_dict.pop(k)
            if k.startswith("superpoint."):
                state_dict[k.replace("superpoint.", "", 1)] = state_dict.pop(k)
        self.detector.load_state_dict(state_dict)

        # ...then keeping only the "model."-prefixed keys for LightGlue.
        state_dict = torch.load(self.ckpt_path, map_location="cpu")
        if "state_dict" in state_dict:
            state_dict = state_dict["state_dict"]
        for k in list(state_dict.keys()):
            if k.startswith("superpoint."):
                state_dict.pop(k)
            if k.startswith("model."):
                state_dict[k.replace("model.", "", 1)] = state_dict.pop(k)
        self.model.load_state_dict(state_dict)

        self.detector = self.detector.eval().to(self.device)
        self.model = self.model.eval().to(self.device)

    def preprocess(self, img):
        # SuperPoint runs on grayscale; also add a batch dimension.
        return rgb_to_grayscale(img.unsqueeze(0))

    def _forward(self, img0, img1):
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        data = dict(image0=img0, image1=img1)

        # Images are matched at their input resolution, so the rescale factors are
        # identity; sizes are stored as (width, height).
        scale0 = torch.tensor([1.0, 1.0]).to(self.device)[None]
        scale1 = torch.tensor([1.0, 1.0]).to(self.device)[None]
        size0 = torch.tensor(data["image0"].shape[-2:][::-1])[None]
        size1 = torch.tensor(data["image1"].shape[-2:][::-1])[None]
        data.update(dict(size0=size0, size1=size1))
        data.update(dict(scale0=scale0, scale1=scale1))

        # Detect keypoints in each image independently, suffixing output keys with 0 / 1.
        pred = {}
        pred.update(
            {k + "0": v for k, v in self.detector({"image": data["image0"], "image_size": data["size0"]}).items()}
        )
        pred.update(
            {k + "1": v for k, v in self.detector({"image": data["image1"], "image_size": data["size1"]}).items()}
        )

        # Match the detections with LightGlue.
        pred.update(self.model({**pred, **data, "resize0": data["size0"], "resize1": data["size1"]}))

        # Undo the (identity) rescaling to express keypoints in input-image pixels.
        kpts0 = torch.cat([kp * s for kp, s in zip(pred["keypoints0"], data["scale0"][:, None])])
        kpts1 = torch.cat([kp * s for kp, s in zip(pred["keypoints1"], data["scale1"][:, None])])

        desc0, desc1 = pred["descriptors0"], pred["descriptors1"]

        # Batch index of every keypoint (all zeros here, since preprocess yields batch size 1).
        m_bids = torch.nonzero(pred["keypoints0"].sum(dim=2) > -1)[:, 0]
        matches = pred["matches"]
        bs = data["image0"].size(0)

        # Gather the matched keypoint pairs per batch element.
        mkpts0 = torch.cat([kpts0[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mkpts1 = torch.cat([kpts1[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1
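

# A minimal smoke-test sketch, not part of the matcher API: it calls _forward directly
# with random tensors standing in for real (3, H, W) images in [0, 1]. In normal use
# these matchers are driven through the BaseMatcher interface on real images.
if __name__ == "__main__":
    matcher = GIM_DKM(device="cuda" if torch.cuda.is_available() else "cpu")
    img0 = torch.rand(3, 480, 640).to(matcher.device)
    img1 = torch.rand(3, 480, 640).to(matcher.device)
    mkpts0, mkpts1, *_ = matcher._forward(img0, img1)
    print(f"GIM-DKM returned {len(mkpts0)} matches")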