from kornia.feature import DeDoDe, LightGlue
import torch
import kornia

from matching.utils import to_numpy
from matching import get_version, BaseMatcher


class DeDoDeLightGlue(BaseMatcher):
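    """DeDoDe keypoints and descriptors matched with LightGlue, both loaded through kornia."""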

    detector_options = ["L-upright", "L-C4", "L-SO2", "L-C4-v2"]
    descriptor_options = ["B-upright", "G-upright", "B-C4", "B-SO2", "G-C4"]

    def __init__(
        self,
        device="cpu",
        detector_weights="L-C4-v2",
        desc_weights="B-upright",
        **kwargs,
    ):
        super().__init__(device, **kwargs)

        major, minor, patch = get_version(kornia)
        assert (major, minor, patch) >= (0, 7, 3), (
            "DeDoDe-LG is only available in kornia v0.7.3 or greater. Update kornia to use this model."
        )

        assert (
            detector_weights in DeDoDeLightGlue.detector_options
        ), f"Invalid detector weights passed ({detector_weights}). Choose from {DeDoDeLightGlue.detector_options}"
        assert (
            desc_weights in DeDoDeLightGlue.descriptor_options
        ), f"Invalid descriptor weights passed ({desc_weights}). Choose from {DeDoDeLightGlue.descriptor_options}"

        desc_type = desc_weights[0].lower()
        self.model = DeDoDe.from_pretrained(
            detector_weights=detector_weights,
            descriptor_weights=desc_weights,
            amp_dtype=torch.float16 if "cuda" in device else torch.float32,
        ).to(device)
        self.lg = LightGlue(features="dedode" + desc_type).to(device).eval()

    def preprocess(self, img):
        # the kornia DeDoDe implementation applies ImageNet normalization and
        # pads inputs whose size is not divisible by the required factor, so we
        # only need to ensure a batch dimension here
        return img.unsqueeze(0) if img.ndim < 4 else img

    def _forward(self, img0, img1):
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        kpts0, scores0, desc0 = self.model(img0)
        kpts1, scores1, desc1 = self.model(img1)

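        # LightGlue expects image_size as (width, height), hence the reversed spatial shape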
        match_input = {
            "image0": {
                "keypoints": kpts0,
                "descriptors": desc0,
                "image_size": torch.tensor(img0.shape[-2:][::-1]).view(1, 2).to(self.device),
            },
            "image1": {
                "keypoints": kpts1,
                "descriptors": desc1,
                "image_size": torch.tensor(img1.shape[-2:][::-1]).view(1, 2).to(self.device),
            },
        }

        matches = self.lg(match_input)

        matching_idxs = matches["matches"][0]
        mkpts0 = kpts0.squeeze()[matching_idxs[:, 0]]
        mkpts1 = kpts1.squeeze()[matching_idxs[:, 1]]

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1
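

if __name__ == "__main__":
    # Minimal usage sketch (not part of the original module). It runs the
    # matcher on random tensors purely as a smoke test, so the resulting
    # matches are not meaningful; real usage would pass (3, H, W) float images
    # in [0, 1]. Calling _forward directly is for illustration only; in this
    # repo the BaseMatcher call interface is normally the entry point.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    matcher = DeDoDeLightGlue(device=device)

    img0 = torch.rand(3, 256, 256, device=device)
    img1 = torch.rand(3, 256, 256, device=device)

    with torch.inference_mode():
        mkpts0, mkpts1, kpts0, kpts1, desc0, desc1 = matcher._forward(img0, img1)

    print(f"{len(mkpts0)} matches from {kpts0.shape[1]} keypoints per image")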