|
import torch |
|
from pathlib import Path |
|
import gdown |
|
import torchvision.transforms as tfm |
|
|
|
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher |
|
from matching.utils import to_numpy, resize_to_divisible, add_to_path |
|
|
|
add_to_path(THIRD_PARTY_DIR.joinpath("XoFTR"), insert=0) |
|
|
|
from src.xoftr import XoFTR |
|
from src.config.default import get_cfg_defaults |
|
from src.utils.misc import lower_config |
|
|
|
|
|
class XoFTRMatcher(BaseMatcher): |
|
weights_src_640 = ( |
|
"https://drive.google.com/file/d/1oRkEGsLpPIxlulc6a7c2q5H1XTNbfkVj/view" |
|
) |
|
weights_src_840 = ( |
|
"https://drive.google.com/file/d/1bexiQGcbZWESb2lp1cTsa7r5al9M4xzj/view" |
|
) |
|
|
|
model_path_640 = WEIGHTS_DIR.joinpath("weights_xoftr_640.ckpt") |
|
model_path_840 = WEIGHTS_DIR.joinpath("weights_xoftr_840.ckpt") |
|
|
|
divisible_size = 8 |
|
|
|
def __init__(self, device="cpu", pretrained_size=640, **kwargs): |
|
super().__init__(device, **kwargs) |
|
|
|
self.pretrained_size = pretrained_size |
|
assert self.pretrained_size in [ |
|
640, |
|
840, |
|
], f"Pretrained size must be in [640, 840], you entered {self.pretrained_size}" |
|
|
|
self.download_weights() |
|
|
|
self.matcher = self.build_matcher(**kwargs) |
|
|
|
def build_matcher(self, coarse_thresh=0.3, fine_thresh=0.1, denser=False): |
|
|
|
config = get_cfg_defaults(inference=True) |
|
config = lower_config(config) |
|
|
|
|
|
config["xoftr"]["match_coarse"]["thr"] = coarse_thresh |
|
config["xoftr"]["fine"]["thr"] = fine_thresh |
|
|
|
|
|
|
|
config["xoftr"]["fine"]["denser"] = denser |
|
|
|
matcher = XoFTR(config=config["xoftr"]) |
|
|
|
ckpt = ( |
|
self.model_path_640 if self.pretrained_size == 640 else self.model_path_840 |
|
) |
|
|
|
|
|
matcher.load_state_dict( |
|
torch.load(ckpt, map_location="cpu")["state_dict"], strict=True |
|
) |
|
|
|
return matcher.eval().to(self.device) |
|
|
|
def download_weights(self): |
|
if ( |
|
not Path(XoFTRMatcher.model_path_640).is_file() |
|
and self.pretrained_size == 640 |
|
): |
|
print("Downloading XoFTR outdoor... (takes a while)") |
|
gdown.download( |
|
XoFTRMatcher.weights_src_640, |
|
output=str(XoFTRMatcher.model_path_640), |
|
fuzzy=True, |
|
) |
|
|
|
if ( |
|
not Path(XoFTRMatcher.model_path_840).is_file() |
|
and self.pretrained_size == 840 |
|
): |
|
print("Downloading XoFTR outdoor... (takes a while)") |
|
gdown.download( |
|
XoFTRMatcher.weights_src_840, |
|
output=str(XoFTRMatcher.model_path_840), |
|
fuzzy=True, |
|
) |
|
|
|
def preprocess(self, img): |
|
_, h, w = img.shape |
|
orig_shape = h, w |
|
img = resize_to_divisible(img, self.divisible_size) |
|
return tfm.Grayscale()(img).unsqueeze(0), orig_shape |
|
|
|
def _forward(self, img0, img1): |
|
img0, img0_orig_shape = self.preprocess(img0) |
|
img1, img1_orig_shape = self.preprocess(img1) |
|
|
|
batch = {"image0": img0, "image1": img1} |
|
|
|
self.matcher(batch) |
|
|
|
mkpts0 = to_numpy(batch["mkpts0_f"]) |
|
mkpts1 = to_numpy(batch["mkpts1_f"]) |
|
|
|
H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:] |
|
mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0) |
|
mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1) |
|
|
|
return mkpts0, mkpts1, None, None, None, None |
|
|