import torch
from pathlib import Path
import gdown
import torchvision.transforms as tfm
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
from matching.utils import to_numpy, resize_to_divisible, add_to_path
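# Make the third-party XoFTR repo importable (provides the `src` package imported below)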
add_to_path(THIRD_PARTY_DIR.joinpath("XoFTR"), insert=0)
from src.xoftr import XoFTR
from src.config.default import get_cfg_defaults
from src.utils.misc import lower_config


class XoFTRMatcher(BaseMatcher):
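    """XoFTR matcher wrapped behind the common BaseMatcher interface."""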
weights_src_640 = (
"https://drive.google.com/file/d/1oRkEGsLpPIxlulc6a7c2q5H1XTNbfkVj/view"
)
weights_src_840 = (
"https://drive.google.com/file/d/1bexiQGcbZWESb2lp1cTsa7r5al9M4xzj/view"
)
model_path_640 = WEIGHTS_DIR.joinpath("weights_xoftr_640.ckpt")
model_path_840 = WEIGHTS_DIR.joinpath("weights_xoftr_840.ckpt")
    divisible_size = 8

    def __init__(self, device="cpu", pretrained_size=640, **kwargs):
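        """Set up XoFTR.

        `pretrained_size` selects which released checkpoint to use (640 or 840).
        """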
super().__init__(device, **kwargs)
self.pretrained_size = pretrained_size
assert self.pretrained_size in [
640,
840,
        ], f"Pretrained size must be one of [640, 840], you entered {self.pretrained_size}"
self.download_weights()
        self.matcher = self.build_matcher(**kwargs)

    def build_matcher(self, coarse_thresh=0.3, fine_thresh=0.1, denser=False):
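        """Configure match thresholds and load the selected XoFTR checkpoint."""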
# Get default configurations
config = get_cfg_defaults(inference=True)
config = lower_config(config)
# Coarse & fine level thresholds
config["xoftr"]["match_coarse"]["thr"] = coarse_thresh # Default 0.3
config["xoftr"]["fine"]["thr"] = fine_thresh # Default 0.1
# It is possible to get denser matches
# If True, xoftr returns all fine-level matches for each fine-level window (at 1/2 resolution)
config["xoftr"]["fine"]["denser"] = denser # Default False
matcher = XoFTR(config=config["xoftr"])
ckpt = (
self.model_path_640 if self.pretrained_size == 640 else self.model_path_840
)
# Load model
matcher.load_state_dict(
torch.load(ckpt, map_location="cpu")["state_dict"], strict=True
)
        return matcher.eval().to(self.device)

    def download_weights(self):
        if self.pretrained_size == 640 and not XoFTRMatcher.model_path_640.is_file():
            print("Downloading XoFTR weights (640)... (takes a while)")
            gdown.download(
                XoFTRMatcher.weights_src_640,
                output=str(XoFTRMatcher.model_path_640),
                fuzzy=True,
            )
        if self.pretrained_size == 840 and not XoFTRMatcher.model_path_840.is_file():
            print("Downloading XoFTR weights (840)... (takes a while)")
            gdown.download(
                XoFTRMatcher.weights_src_840,
                output=str(XoFTRMatcher.model_path_840),
                fuzzy=True,
            )

    def preprocess(self, img):
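        """Resize so height/width are divisible by 8, grayscale, and add a batch dim.

        Also returns the original (h, w) so matches can be rescaled back later.
        """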
_, h, w = img.shape
orig_shape = h, w
img = resize_to_divisible(img, self.divisible_size)
        return tfm.Grayscale()(img).unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
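        """Match an image pair; return fine-level matches in the original image coords.

        XoFTR is detector-free, so the keypoint and descriptor slots of the
        return tuple are None.
        """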
img0, img0_orig_shape = self.preprocess(img0)
img1, img1_orig_shape = self.preprocess(img1)
batch = {"image0": img0, "image1": img1}
self.matcher(batch)
mkpts0 = to_numpy(batch["mkpts0_f"])
mkpts1 = to_numpy(batch["mkpts1_f"])
H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)
return mkpts0, mkpts1, None, None, None, None
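

if __name__ == "__main__":
    # Minimal smoke-test sketch (assumptions: weights are downloaded on first run, and
    # calling _forward directly with CHW float tensors in [0, 1] is sufficient here;
    # normal use would go through the library's BaseMatcher call path instead).
    matcher = XoFTRMatcher(device="cpu", pretrained_size=640)
    img0 = torch.rand(3, 480, 640)  # placeholder images; replace with real RGB tensors
    img1 = torch.rand(3, 480, 640)
    with torch.inference_mode():
        mkpts0, mkpts1, *_ = matcher._forward(img0, img1)
    print(f"{len(mkpts0)} matches")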