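"""Se2-LoFTR matcher wrapper.

Exposes the rotation-equivariant Se2-LoFTR models through the common BaseMatcher
interface, handling weight download, config selection, and image preprocessing.
"""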
import os

import gdown
import torch
import torchvision.transforms as tfm

from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
from matching.utils import to_numpy, resize_to_divisible, lower_config, add_to_path
# Put the vendored Se2_LoFTR checkout on sys.path so its internal `src` and `configs`
# packages (used by the imports below) resolve.
add_to_path(THIRD_PARTY_DIR.joinpath("Se2_LoFTR"), insert=0)

from src.loftr.loftr import LoFTR
from configs.loftr.outdoor.loftr_ds_e2_dense_8rot import cfg as rot8_cfg
from configs.loftr.outdoor.loftr_ds_e2_dense_big import cfg as big_cfg
from configs.loftr.outdoor.loftr_ds_e2_dense import cfg as e2dense_cfg
from configs.loftr.outdoor.loftr_ds_e2 import cfg as e2_cfg


class Se2LoFTRMatcher(BaseMatcher):
    # dense and base LoFTR have shape mismatches on state-dict load
    configs = {
        "rot8": rot8_cfg,
        "big": big_cfg,
        "dense": e2dense_cfg,
        "rot4": e2_cfg,
        # "loftr": baseline_cfg,
    }

    weights = {
        "rot8": "se2loftr_rot8.pt",
        "big": "se2loftr_rot4_big.pt",
        "dense": "se2loftr_rot4_dense.pt",
        "rot4": "se2loftr_rot4.pt",
        # "loftr": "baseline.ckpt",
    }

    weights_url = {
        # weight files (.pt) only
        "rot8": "https://drive.google.com/file/d/1ulaJE25hMOYYxZsnPgLQXPqGFQv_06-O/view",
        "big": "https://drive.google.com/file/d/145i4KqbyCg6J1JdJTa0A05jVp_7ckebq/view",
        "dense": "https://drive.google.com/file/d/1QMDgOzhIB5zjm-K5Sltcpq7wF94ZpwE7/view",
        "rot4": "https://drive.google.com/file/d/19c00PuTtbQO4KxVod3G0FBr_MWrqts4c/view",
        # original ckpts (require pytorch lightning to load)
        # "rot8": "https://drive.google.com/file/d/1jPtOTxmwo1Z_YYP2YMS6efOevDaNiJR4/view",
        # "big": "https://drive.google.com/file/d/1AE_EmmhQLfArIP-zokSlleY2YiSgBV3m/view",
        # "dense": "https://drive.google.com/file/d/17vxdnVtjVuq2m8qJsOG1JFfJjAqcgr4j/view",
        # "rot4": "https://drive.google.com/file/d/17vxdnVtjVuq2m8qJsOG1JFfJjAqcgr4j/view",
        # "loftr": "https://drive.google.com/file/d/1OylPSrbjzRJgvLHM3qJPAVpW3BEQeuFS/view",
    }

    divisible_size = 32

    def __init__(self, device="cpu", max_num_keypoints=0, loftr_config="rot8", *args, **kwargs) -> None:
        super().__init__(device)
        assert loftr_config in self.configs, f"Config not found. Must choose from {list(self.configs.keys())}"
        # max_num_keypoints is accepted for interface compatibility but unused: Se2-LoFTR is detector-free.
        self.loftr_config = loftr_config
        self.weights_path = WEIGHTS_DIR.joinpath(Se2LoFTRMatcher.weights[self.loftr_config])

        self.download_weights()
        self.model = self.load_model(self.loftr_config, device)

    def download_weights(self):
        if not os.path.isfile(self.weights_path):
            os.makedirs(WEIGHTS_DIR, exist_ok=True)  # ensure the weights directory exists
            print(f"Downloading {Se2LoFTRMatcher.weights_url[self.loftr_config]}")
            gdown.download(
                Se2LoFTRMatcher.weights_url[self.loftr_config],
                output=str(self.weights_path),
                fuzzy=True,
            )

    def load_model(self, config, device="cpu"):
        model = LoFTR(config=lower_config(Se2LoFTRMatcher.configs[config])["loftr"]).to(self.device)
        # To load the original PyTorch Lightning checkpoints instead, strip the "matcher." prefix:
        # model.load_state_dict(
        #     {
        #         k.replace("matcher.", ""): v
        #         for k, v in torch.load(self.weights_path, map_location=device)["state_dict"].items()
        #     }
        # )
        model.load_state_dict(torch.load(str(self.weights_path), map_location=device))
        return model.eval()

    def preprocess(self, img):
        # LoFTR requires grayscale images with dimensions divisible by 32
        _, h, w = img.shape
        orig_shape = h, w
        img = resize_to_divisible(img, self.divisible_size)
        return tfm.Grayscale()(img).unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        batch = {"image0": img0, "image1": img1}
        # LoFTR returns nothing; it writes its results into the batch dict, which afterwards
        # holds keys such as 'mkpts0_f', 'mkpts1_f', 'expec_f', 'mkpts0_c', 'mkpts1_c',
        # 'mconf', 'm_bids', 'gt_mask'.
        self.model(batch)

        mkpts0 = to_numpy(batch["mkpts0_f"])
        mkpts1 = to_numpy(batch["mkpts1_f"])

        # rescale matches from the resized (divisible-by-32) images back to the original resolutions
        H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        # detector-free matcher: no standalone keypoints, descriptors, or scores to return
        return mkpts0, mkpts1, None, None, None, None
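

# Usage sketch (illustrative only): random tensors stand in for real CHW images and the
# private _forward is called directly; in practice images would come from the library's
# own loading utilities. Everything referenced here is defined in this module.
if __name__ == "__main__":
    matcher = Se2LoFTRMatcher(device="cpu", loftr_config="rot8")
    img0 = torch.rand(3, 480, 640)
    img1 = torch.rand(3, 480, 640)
    mkpts0, mkpts1, *_ = matcher._forward(img0, img1)
    print(f"found {len(mkpts0)} matches")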