|
import torchvision.transforms as tfm |
|
import torch |
|
import os |
|
import gdown |
|
|
|
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher |
|
from matching.utils import to_numpy, resize_to_divisible, lower_config, add_to_path |
|
|
|
|
|
add_to_path(THIRD_PARTY_DIR.joinpath("Se2_LoFTR"), insert=0) |
|
from src.loftr.loftr import LoFTR |
|
from configs.loftr.outdoor.loftr_ds_e2_dense_8rot import cfg as rot8_cfg |
|
from configs.loftr.outdoor.loftr_ds_e2_dense_big import cfg as big_cfg |
|
from configs.loftr.outdoor.loftr_ds_e2_dense import cfg as e2dense_cfg |
|
from configs.loftr.outdoor.loftr_ds_e2 import cfg as e2_cfg |
|
|
|
|
|
class Se2LoFTRMatcher(BaseMatcher):
    """Matcher wrapping the Se2-LoFTR model from ``THIRD_PARTY_DIR/Se2_LoFTR``.

    A variant key (see ``configs``) selects the model configuration, the
    checkpoint file name in ``weights`` (stored under ``WEIGHTS_DIR``) and the
    Google Drive URL in ``weights_url`` used to fetch that checkpoint on first use.
    """

    # Variant key -> Se2-LoFTR config object (imported from configs.loftr.outdoor.*).
    configs = {
        "rot8": rot8_cfg,
        "big": big_cfg,
        "dense": e2dense_cfg,
        "rot4": e2_cfg,
    }

    # Variant key -> checkpoint file name inside WEIGHTS_DIR.
    weights = {
        "rot8": "se2loftr_rot8.pt",
        "big": "se2loftr_rot4_big.pt",
        "dense": "se2loftr_rot4_dense.pt",
        "rot4": "se2loftr_rot4.pt",
    }

    # Variant key -> Google Drive share link for the checkpoint.
    weights_url = {
        "rot8": "https://drive.google.com/file/d/1ulaJE25hMOYYxZsnPgLQXPqGFQv_06-O/view",
        "big": "https://drive.google.com/file/d/145i4KqbyCg6J1JdJTa0A05jVp_7ckebq/view",
        "dense": "https://drive.google.com/file/d/1QMDgOzhIB5zjm-K5Sltcpq7wF94ZpwE7/view",
        "rot4": "https://drive.google.com/file/d/19c00PuTtbQO4KxVod3G0FBr_MWrqts4c/view",
    }

    # Input images are resized so that height and width are multiples of this value.
    divisible_size = 32

    def __init__(self, device="cpu", max_num_keypoints=0, loftr_config="rot8", *args, **kwargs) -> None:
        """Build the matcher, downloading the checkpoint first if it is missing.

        Args:
            device: device passed to ``BaseMatcher``; the model is placed on
                ``self.device``.
            max_num_keypoints: accepted for interface compatibility with other
                matchers; not used by this class.
            loftr_config: variant key; must be one of ``configs``.
            *args, **kwargs: accepted and ignored.

        Raises:
            AssertionError: if ``loftr_config`` is not a known variant.
        """
        super().__init__(device)
        assert loftr_config in self.configs.keys(), f"Config not found. Must choose from {self.configs.keys()}"
        self.loftr_config = loftr_config

        self.weights_path = WEIGHTS_DIR.joinpath(self.weights[self.loftr_config])
        self.download_weights()
        self.model = self.load_model(self.loftr_config, device)

    def download_weights(self):
        """Download the selected variant's checkpoint to ``self.weights_path`` if absent.

        Raises:
            FileNotFoundError: if the checkpoint is still missing after the
                download attempt (e.g. gdown could not retrieve the file).
        """
        if os.path.isfile(self.weights_path):
            return

        url = self.weights_url[self.loftr_config]
        print(f"Downloading {url}")
        # Ensure the destination directory exists before gdown writes into it.
        self.weights_path.parent.mkdir(parents=True, exist_ok=True)
        # fuzzy=True lets gdown extract the file id from a Drive ".../view" share link.
        gdown.download(
            url,
            output=str(self.weights_path),
            fuzzy=True,
        )
        # Fail here with a clear message instead of an obscure error in torch.load later.
        if not os.path.isfile(self.weights_path):
            raise FileNotFoundError(f"Could not download Se2-LoFTR weights from {url} to {self.weights_path}")

    def load_model(self, config, device="cpu"):
        """Instantiate Se2-LoFTR for ``config`` and load the checkpoint at ``self.weights_path``.

        Args:
            config: variant key into ``configs``.
            device: ``map_location`` used by ``torch.load``; the model itself is
                moved to ``self.device``.

        Returns:
            The model, switched to eval mode.
        """
        model = LoFTR(config=lower_config(self.configs[config])["loftr"]).to(self.device)
        # NOTE(review): torch.load unpickles the checkpoint; it comes from a fixed
        # download URL, so only point weights_path at trusted files.
        state_dict = torch.load(str(self.weights_path), map_location=device)
        model.load_state_dict(state_dict)
        return model.eval()

    def preprocess(self, img):
        """Resize ``img`` to multiples of ``divisible_size`` and convert it to a grayscale batch.

        Args:
            img: 3-D image tensor unpacked as ``(C, H, W)``; C is presumably 1 or 3
                (what ``Grayscale`` accepts) — TODO confirm with callers.

        Returns:
            ``(gray, (H, W))``: the single-channel image with a leading batch
            dimension of 1, and the height/width before resizing.
        """
        _, h, w = img.shape
        orig_shape = h, w
        img = resize_to_divisible(img, self.divisible_size)
        return tfm.Grayscale()(img).unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        """Match two images with Se2-LoFTR.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``: the model's fine-level
            matched keypoints (``mkpts0_f`` / ``mkpts1_f``, converted with
            ``to_numpy``) rescaled from the resized inputs back to the original
            image sizes; no separate keypoints or descriptors are produced.
        """
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        batch = {"image0": img0, "image1": img1}
        # The model writes its outputs into ``batch`` in place; the return value is unused.
        self.model(batch)

        mkpts0 = to_numpy(batch["mkpts0_f"])
        mkpts1 = to_numpy(batch["mkpts1_f"])

        # Sizes of the resized model inputs, used to map coordinates back.
        H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        return mkpts0, mkpts1, None, None, None, None
|
|