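"""XoFTR matcher wrapper.

Exposes the third-party XoFTR model through the BaseMatcher interface:
downloads a pretrained checkpoint, builds the model from XoFTR's default
config, and returns fine-level matches mapped back to original image
coordinates.
"""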
import torch
from pathlib import Path
import gdown
import torchvision.transforms as tfm

from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
from matching.utils import to_numpy, resize_to_divisible, add_to_path

# Make the vendored XoFTR repo importable before loading its modules
add_to_path(THIRD_PARTY_DIR.joinpath("XoFTR"), insert=0)

from src.xoftr import XoFTR
from src.config.default import get_cfg_defaults
from src.utils.misc import lower_config


class XoFTRMatcher(BaseMatcher):
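    # Google Drive share links for the pretrained checkpoints (640 and 840
    # input-resolution variants); downloaded on demand via gdown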
    weights_src_640 = (
        "https://drive.google.com/file/d/1oRkEGsLpPIxlulc6a7c2q5H1XTNbfkVj/view"
    )
    weights_src_840 = (
        "https://drive.google.com/file/d/1bexiQGcbZWESb2lp1cTsa7r5al9M4xzj/view"
    )

    model_path_640 = WEIGHTS_DIR.joinpath("weights_xoftr_640.ckpt")
    model_path_840 = WEIGHTS_DIR.joinpath("weights_xoftr_840.ckpt")

    # Inputs are resized so H and W are divisible by 8 before inference
    divisible_size = 8

    def __init__(self, device="cpu", pretrained_size=640, **kwargs):
        super().__init__(device, **kwargs)

        self.pretrained_size = pretrained_size
        assert self.pretrained_size in [
            640,
            840,
        ], f"Pretrained size must be in [640, 840], you entered {self.pretrained_size}"

        self.download_weights()

        self.matcher = self.build_matcher(**kwargs)

    def build_matcher(self, coarse_thresh=0.3, fine_thresh=0.1, denser=False):
        # Get default configurations
        config = get_cfg_defaults(inference=True)
        config = lower_config(config)

        # Coarse- and fine-level match thresholds
        config["xoftr"]["match_coarse"]["thr"] = coarse_thresh  # Default 0.3
        config["xoftr"]["fine"]["thr"] = fine_thresh  # Default 0.1

        # Denser matching: if True, XoFTR returns all fine-level matches for
        # each fine-level window (at 1/2 resolution)
        config["xoftr"]["fine"]["denser"] = denser  # Default False

        matcher = XoFTR(config=config["xoftr"])

        ckpt = (
            self.model_path_640 if self.pretrained_size == 640 else self.model_path_840
        )

        # Load model
        matcher.load_state_dict(
            torch.load(ckpt, map_location="cpu")["state_dict"], strict=True
        )

        return matcher.eval().to(self.device)
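
    # Sketch: the options above can be tuned from the constructor, since
    # __init__ forwards extra kwargs to build_matcher (values below are
    # illustrative, and this assumes BaseMatcher tolerates the same kwargs):
    #
    #   matcher = XoFTRMatcher(
    #       device="cuda", pretrained_size=840,
    #       coarse_thresh=0.5, fine_thresh=0.2, denser=True,
    #   )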

    def download_weights(self):
        # Pick the checkpoint matching the requested pretrained size
        if self.pretrained_size == 640:
            weights_src, model_path = self.weights_src_640, self.model_path_640
        else:
            weights_src, model_path = self.weights_src_840, self.model_path_840

        if not Path(model_path).is_file():
            print(f"Downloading XoFTR {self.pretrained_size} weights... (takes a while)")
            # fuzzy=True lets gdown extract the file id from a share-link URL
            gdown.download(weights_src, output=str(model_path), fuzzy=True)

    def preprocess(self, img):
        # XoFTR takes grayscale, batched input with H and W divisible by 8
        _, h, w = img.shape
        orig_shape = h, w
        img = resize_to_divisible(img, self.divisible_size)
        return tfm.Grayscale()(img).unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        batch = {"image0": img0, "image1": img1}

        # XoFTR writes its predictions into the batch dict in place
        self.matcher(batch)

        mkpts0 = to_numpy(batch["mkpts0_f"])
        mkpts1 = to_numpy(batch["mkpts1_f"])

        # Map match coordinates from the resized images back to the original shapes
        H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        # Detector-free method: no standalone keypoints or descriptors to return
        return mkpts0, mkpts1, None, None, None, None
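

# Minimal usage sketch. Assumptions not made by this module: images are read
# as uint8 CHW tensors and scaled to float [0, 1]; the paths are placeholders.
if __name__ == "__main__":
    from torchvision.io import read_image

    matcher = XoFTRMatcher(device="cpu", pretrained_size=640)

    img0 = read_image("img0.png").float() / 255.0  # placeholder path
    img1 = read_image("img1.png").float() / 255.0  # placeholder path

    # Run the matcher; only matched point coordinates are returned
    with torch.no_grad():
        mkpts0, mkpts1, *_ = matcher._forward(img0, img1)

    print(f"Found {len(mkpts0)} matches")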