import torchvision.transforms as tfm
import torch
import os
import gdown

from matching import WEIGHTS_DIR, THIRD_PARTY_DIR, BaseMatcher
from matching.utils import to_numpy, resize_to_divisible, lower_config, add_to_path


add_to_path(THIRD_PARTY_DIR.joinpath("Se2_LoFTR"), insert=0)
from src.loftr.loftr import LoFTR
from configs.loftr.outdoor.loftr_ds_e2_dense_8rot import cfg as rot8_cfg
from configs.loftr.outdoor.loftr_ds_e2_dense_big import cfg as big_cfg
from configs.loftr.outdoor.loftr_ds_e2_dense import cfg as e2dense_cfg
from configs.loftr.outdoor.loftr_ds_e2 import cfg as e2_cfg


class Se2LoFTRMatcher(BaseMatcher):
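    """Matcher wrapping the SE(2)-equivariant LoFTR (Se2-LoFTR) models.

    Selects one of the pretrained configs ("rot8", "big", "dense", "rot4"),
    downloads the corresponding weights from Google Drive via gdown on first use,
    and exposes dense matching through the BaseMatcher interface.
    """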
    # dense and base loftr have shape mismatches in state dict load
    configs = {
        "rot8": rot8_cfg,
        "big": big_cfg,
        "dense": e2dense_cfg,
        "rot4": e2_cfg,
        # 'loftr': baseline_cfg
    }

    weights = {
        "rot8": "se2loftr_rot8.pt",
        "big": "se2loftr_rot4_big.pt",
        "dense": "se2loftr_rot4_dense.pt",
        "rot4": "se2loftr_rot4.pt",
        #    'loftr': 'baseline.ckpt'
    }

    weights_url = {
        # weight files (.pt) only
        "rot8": "https://drive.google.com/file/d/1ulaJE25hMOYYxZsnPgLQXPqGFQv_06-O/view",
        "big": "https://drive.google.com/file/d/145i4KqbyCg6J1JdJTa0A05jVp_7ckebq/view",
        "dense": "https://drive.google.com/file/d/1QMDgOzhIB5zjm-K5Sltcpq7wF94ZpwE7/view",
        "rot4": "https://drive.google.com/file/d/19c00PuTtbQO4KxVod3G0FBr_MWrqts4c/view",
        # original ckpts (requires pytorch lightning to load)
        # "rot8": "https://drive.google.com/file/d/1jPtOTxmwo1Z_YYP2YMS6efOevDaNiJR4/view",
        # "big": "https://drive.google.com/file/d/1AE_EmmhQLfArIP-zokSlleY2YiSgBV3m/view",
        # 'dense':'https://drive.google.com/file/d/17vxdnVtjVuq2m8qJsOG1JFfJjAqcgr4j/view',
        # 'rot4': 'https://drive.google.com/file/d/17vxdnVtjVuq2m8qJsOG1JFfJjAqcgr4j/view'
        # 'loftr': 'https://drive.google.com/file/d/1OylPSrbjzRJgvLHM3qJPAVpW3BEQeuFS/view'
    }

    divisible_size = 32

    def __init__(self, device="cpu", max_num_keypoints=0, loftr_config="rot8", *args, **kwargs) -> None:
        super().__init__(device)
        assert loftr_config in self.configs, f"Config '{loftr_config}' not found. Must choose from {list(self.configs.keys())}"
        self.loftr_config = loftr_config

        self.weights_path = WEIGHTS_DIR.joinpath(Se2LoFTRMatcher.weights[self.loftr_config])

        self.download_weights()

        self.model = self.load_model(self.loftr_config, device)

    def download_weights(self):
        if not os.path.isfile(self.weights_path):
            print(f"Downloading {Se2LoFTRMatcher.weights_url[self.loftr_config]}")
            gdown.download(
                Se2LoFTRMatcher.weights_url[self.loftr_config],
                output=str(self.weights_path),
                fuzzy=True,
            )

    def load_model(self, config, device="cpu"):
        model = LoFTR(config=lower_config(Se2LoFTRMatcher.configs[config])["loftr"]).to(self.device)
        # To load one of the original PyTorch-Lightning .ckpt files instead of the
        # exported .pt weights, strip the "matcher." prefix from the checkpoint keys:
        # model.load_state_dict(
        #     {
        #         k.replace("matcher.", ""): v
        #         for k, v in torch.load(self.weights_path, map_location=device)["state_dict"].items()
        #     }
        # )
        print(f"Loading weights from {self.weights_path}")
        model.load_state_dict(torch.load(str(self.weights_path), map_location=device))
        return model.eval()

    def preprocess(self, img):
        # loftr requires grayscale imgs divisible by 32
        _, h, w = img.shape
        orig_shape = h, w
        img = resize_to_divisible(img, self.divisible_size)
        return tfm.Grayscale()(img).unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        img0, img0_orig_shape = self.preprocess(img0)
        img1, img1_orig_shape = self.preprocess(img1)

        batch = {"image0": img0, "image1": img1}
        self.model(batch)  # loftr does not return anything, instead stores results in batch dict
        # batch now has keys: ['mkpts0_f', 'mkpts1_f', 'expec_f','mkpts0_c', 'mkpts1_c', 'mconf', 'm_bids','gt_mask']

        mkpts0 = to_numpy(batch["mkpts0_f"])
        mkpts1 = to_numpy(batch["mkpts1_f"])

        H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:]
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)

        return mkpts0, mkpts1, None, None, None, None
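

# A minimal usage sketch, not part of the original module. Assumptions: the image
# paths below are hypothetical placeholders, images are loaded as float CHW tensors
# in [0, 1], and the matcher is invoked through `_forward` directly rather than
# through whatever call interface BaseMatcher exposes.
if __name__ == "__main__":
    from torchvision.io import read_image

    device = "cuda" if torch.cuda.is_available() else "cpu"
    matcher = Se2LoFTRMatcher(device=device, loftr_config="rot8")

    # Hypothetical example images; any pair of RGB images works.
    img0 = read_image("assets/im_A.jpg").float().to(device) / 255.0
    img1 = read_image("assets/im_B.jpg").float().to(device) / 255.0

    with torch.no_grad():
        mkpts0, mkpts1, *_ = matcher._forward(img0, img1)
    print(f"Found {len(mkpts0)} matches")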