# Pawel Piwowarski
# init commit
# 0a82b18
import torch
import py3_wget
import gdown
from kornia.color import rgb_to_grayscale
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
from matching.utils import load_module, add_to_path
BASE_PATH = THIRD_PARTY_DIR.joinpath("gim")
add_to_path(BASE_PATH)
from dkm.models.model_zoo.DKMv3 import DKMv3
class GIM_DKM(BaseMatcher):
    """Matcher that wraps a ``DKMv3`` model loaded with the GIM checkpoint.

    The model produces a dense warp between the two images, from which up to
    ``max_num_keypoints`` sparse correspondences are sampled. No keypoints or
    descriptors are produced, so ``_forward`` returns ``None`` for those slots.
    """

    weights_src = "https://drive.google.com/file/d/1gk97V4IROnR1Nprq10W9NCFUv2mxXR_-/view"

    def __init__(self, device="cpu", max_num_keypoints=5000, **kwargs):
        """Build the model, fetch the checkpoint if missing and load it.

        Args:
            device: Device the model is moved to after the weights are loaded.
            max_num_keypoints: Number of matches sampled from the dense warp.
            **kwargs: Forwarded unchanged to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        self.ckpt_path = WEIGHTS_DIR / "gim_dkm_100h.ckpt"

        self.model = DKMv3(weights=None, h=672, w=896)
        self.max_num_keypoints = max_num_keypoints

        self.download_weights()
        self.load_weights()

        self.model = self.model.eval().to(device)

    def download_weights(self):
        """Download the checkpoint from ``weights_src`` unless it already exists."""
        if self.ckpt_path.exists():
            return
        print(f"Downloading {self.ckpt_path.name}")
        # Make sure the destination directory exists before writing into it.
        self.ckpt_path.parent.mkdir(parents=True, exist_ok=True)
        gdown.download(GIM_DKM.weights_src, output=str(self.ckpt_path), fuzzy=True)

    def load_weights(self):
        """Load the checkpoint at ``self.ckpt_path`` into ``self.model``.

        A leading ``model.`` prefix is stripped from parameter names, and every
        entry whose name contains ``encoder.net.fc`` is dropped before loading.
        """
        checkpoint = torch.load(self.ckpt_path, map_location="cpu")
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]

        prefix = "model."
        cleaned = {}
        for name, tensor in checkpoint.items():
            # Decide on the original name first: popping the same key twice
            # (rename, then drop) raised KeyError for "model.encoder.net.fc.*".
            if "encoder.net.fc" in name:
                continue
            if name.startswith(prefix):
                name = name[len(prefix):]
            cleaned[name] = tensor

        self.model.load_state_dict(cleaned)

    def preprocess(self, img):
        """Return ``img`` with a leading batch dimension added when it has fewer than 4 dims."""
        if img.ndim < 4:
            img = img.unsqueeze(0)
        return img

    @staticmethod
    def _to_pixel_coords(normalized, height, width):
        """Rescale (x, y) coordinates so that -1 maps to 0 and +1 maps to ``width`` / ``height``."""
        return torch.stack(
            (width * (normalized[:, 0] + 1) / 2, height * (normalized[:, 1] + 1) / 2),
            dim=-1,
        )

    def _forward(self, img0, img1):
        """Match two images and return the matched points in pixel coordinates.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``: the matched points of
            each image; the remaining slots are always ``None``.
        """
        # Record the sizes before preprocessing; only the last two dims are read.
        height0, width0 = img0.shape[-2:]
        height1, width1 = img1.shape[-2:]

        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)
        dense_matches, dense_certainty = self.model.match(img0, img1, device=self.device)
        torch.cuda.empty_cache()

        # Sample sparse correspondences from the dense warp; the sampled
        # certainty values are not part of the return value.
        sparse_matches, _ = self.model.sample(dense_matches, dense_certainty, self.max_num_keypoints)
        torch.cuda.empty_cache()

        # Columns 0-1 belong to the first image, the remaining columns to the second.
        mkpts0 = self._to_pixel_coords(sparse_matches[:, :2], height0, width0)
        mkpts1 = self._to_pixel_coords(sparse_matches[:, 2:], height1, width1)

        return mkpts0, mkpts1, None, None, None, None
class GIM_LG(BaseMatcher):
    """Matcher that pairs a SuperPoint detector with a LightGlue model using the GIM checkpoint.

    Both networks come from the ``gluefactory`` copy under ``BASE_PATH``, which
    is loaded under the module name ``gluefactory_gim``.
    """

    # Use the "/raw/" endpoint: a "/blob/" URL serves GitHub's HTML viewer page,
    # not the checkpoint file itself.
    weights_src = "https://github.com/xuelunshen/gim/raw/main/weights/gim_lightglue_100h.ckpt"
    superpoint_v1_weight_src = (
        "https://github.com/magicleap/SuperGluePretrainedNetwork/raw/master/models/weights/superpoint_v1.pth"
    )

    def __init__(self, device="cpu", max_keypoints=2048, **kwargs):
        """Create detector and matcher, fetch missing weights and load them.

        Args:
            device: Device both networks are moved to once the weights are loaded.
            max_keypoints: Passed to SuperPoint as ``max_num_keypoints``.
            **kwargs: Forwarded unchanged to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        # Load the altered version of gluefactory; it must be registered
        # before the imports below can resolve.
        load_module("gluefactory_gim", BASE_PATH.joinpath("gluefactory/__init__.py"))
        from gluefactory_gim.superpoint import SuperPoint
        from gluefactory_gim.models.matchers.lightglue import LightGlue

        self.ckpt_path = BASE_PATH / "weights" / "gim_lightglue_100h.ckpt"
        self.superpoint_v1_path = BASE_PATH / "weights" / "superpoint_v1.pth"

        self.download_weights()

        self.detector = SuperPoint(
            {
                "max_num_keypoints": max_keypoints,
                "force_num_keypoints": True,
                "detection_threshold": 0.0,
                "nms_radius": 3,
                "trainable": False,
            }
        )

        self.model = LightGlue(
            {
                "filter_threshold": 0.1,
                "flash": False,
                "checkpointed": True,
            }
        )

        self.load_weights()

    def download_weights(self):
        """Download the GIM checkpoint and the SuperPoint v1 weights if either is missing."""
        downloads = (
            (GIM_LG.weights_src, self.ckpt_path),
            (GIM_LG.superpoint_v1_weight_src, self.superpoint_v1_path),
        )
        for url, target in downloads:
            if target.exists():
                continue
            print(f"Downloading {target.name}")
            # Make sure the destination directory exists before writing into it.
            target.parent.mkdir(parents=True, exist_ok=True)
            py3_wget.download_file(url, target)

    @staticmethod
    def _select_state(state_dict, keep, drop):
        """Return a copy of ``state_dict`` without ``drop``-prefixed keys and with ``keep`` stripped.

        Keys carrying neither prefix are copied through unchanged.
        """
        selected = {}
        for name, tensor in state_dict.items():
            if name.startswith(drop):
                continue
            if name.startswith(keep):
                name = name[len(keep):]
            selected[name] = tensor
        return selected

    def load_weights(self):
        """Split the combined checkpoint and load it into detector and matcher.

        Entries prefixed ``superpoint.`` go to ``self.detector`` and entries
        prefixed ``model.`` go to ``self.model``, each with its prefix removed.
        Both networks are then switched to eval mode and moved to
        ``self.device``.
        """
        # Read the checkpoint once; both networks are fed from the same dict.
        checkpoint = torch.load(self.ckpt_path, map_location="cpu")
        if "state_dict" in checkpoint:
            checkpoint = checkpoint["state_dict"]

        detector_state = self._select_state(checkpoint, keep="superpoint.", drop="model.")
        self.detector.load_state_dict(detector_state)

        matcher_state = self._select_state(checkpoint, keep="model.", drop="superpoint.")
        self.model.load_state_dict(matcher_state)

        self.detector = self.detector.eval().to(self.device)
        self.model = self.model.eval().to(self.device)

    def preprocess(self, img):
        """Convert an RGB image tensor to grayscale, adding a batch dim when it has fewer than 4 dims."""
        # Only unsqueeze unbatched input so that an already-batched tensor is
        # not turned into a 5-D one.
        if img.ndim < 4:
            img = img.unsqueeze(0)
        return rgb_to_grayscale(img)

    def _forward(self, img0, img1):
        """Detect keypoints in both images, match them and gather the matched points.

        Returns:
            ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)``: matched points,
            all detected keypoints and the detector's descriptors for each
            image. Match scores are not returned.

        NOTE(review): the scale tensors have a batch dimension of 1, so this
        looks written for a single image pair per call -- confirm before
        passing larger batches.
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        data = {
            "image0": img0,
            "image1": img1,
            # Last two dims reversed, i.e. stored in the opposite order to the
            # tensor layout.
            "size0": torch.tensor(img0.shape[-2:][::-1])[None],
            "size1": torch.tensor(img1.shape[-2:][::-1])[None],
            # Unit scale: the keypoints are left in the detector's coordinates.
            "scale0": torch.tensor([1.0, 1.0]).to(self.device)[None],
            "scale1": torch.tensor([1.0, 1.0]).to(self.device)[None],
        }

        # Run the detector on each image and suffix its outputs with "0" / "1".
        pred = {}
        for idx in ("0", "1"):
            detections = self.detector(
                {
                    "image": data["image" + idx],
                    "image_size": data["size" + idx],
                }
            )
            pred.update({key + idx: value for key, value in detections.items()})

        pred.update(self.model({**pred, **data, "resize0": data["size0"], "resize1": data["size1"]}))

        kpts0 = torch.cat([kp * s for kp, s in zip(pred["keypoints0"], data["scale0"][:, None])])
        kpts1 = torch.cat([kp * s for kp, s in zip(pred["keypoints1"], data["scale1"][:, None])])
        desc0, desc1 = pred["descriptors0"], pred["descriptors1"]

        # Batch index of every keypoint whose coordinate sum is greater than -1.
        m_bids = torch.nonzero(pred["keypoints0"].sum(dim=2) > -1)[:, 0]
        matches = pred["matches"]
        bs = data["image0"].size(0)

        # Column 0 of each match indexes into image 0's keypoints, column 1
        # into image 1's.
        mkpts0 = torch.cat([kpts0[m_bids == b_id][matches[b_id][..., 0]] for b_id in range(bs)])
        mkpts1 = torch.cat([kpts1[m_bids == b_id][matches[b_id][..., 1]] for b_id in range(bs)])

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1