Pawel Piwowarski
init commit
0a82b18
from kornia.feature import DeDoDe, LightGlue
import torch
import kornia
from matching.utils import to_numpy
from matching import get_version, BaseMatcher
class DeDoDeLightGlue(BaseMatcher):
    """DeDoDe keypoints/descriptors matched with LightGlue, both taken from kornia.

    Args:
        device: Device the models are moved to. DeDoDe is asked for float16
            AMP when the device's string form contains "cuda", else float32.
        detector_weights: DeDoDe detector checkpoint, one of
            ``detector_options``.
        desc_weights: DeDoDe descriptor checkpoint, one of
            ``descriptor_options``. Its first letter ("B" or "G") selects the
            LightGlue feature set ("dedodeb" or "dedodeg").
        **kwargs: Forwarded unchanged to ``BaseMatcher.__init__``.
    """

    detector_options = ["L-upright", "L-C4", "L-SO2", "L-C4-v2"]
    descriptor_options = ["B-upright", "G-upright", "B-C4", "B-SO2", "G-C4"]

    def __init__(
        self,
        device="cpu",
        detector_weights="L-C4-v2",
        desc_weights="B-upright",
        **kwargs,
    ):
        super().__init__(device, **kwargs)

        # Compare full (major, minor, patch) tuples lexicographically so that
        # releases such as 0.8.0 or 1.x pass; a per-field test like
        # "minor >= 7 and patch >= 3" would wrongly reject them.
        major, minor, patch = get_version(kornia)
        assert (
            (major, minor, patch) >= (0, 7, 3)
        ), "DeDoDe-LG only available in kornia v 0.7.3 or greater. Update kornia to use this model."
        assert (
            detector_weights in DeDoDeLightGlue.detector_options
        ), f"Invalid detector weights passed ({detector_weights}). Choose from {DeDoDeLightGlue.detector_options}"
        assert (
            desc_weights in DeDoDeLightGlue.descriptor_options
        ), f"Invalid descriptor weights passed ({desc_weights}). Choose from {DeDoDeLightGlue.descriptor_options}"

        # str() so both plain strings and torch.device objects are accepted.
        amp_dtype = torch.float16 if "cuda" in str(device) else torch.float32
        self.model = (
            DeDoDe.from_pretrained(
                detector_weights=detector_weights,
                descriptor_weights=desc_weights,
                amp_dtype=amp_dtype,
            )
            .to(device)
            .eval()
        )

        # "B-..." -> "dedodeb", "G-..." -> "dedodeg"
        desc_type = desc_weights[0].lower()
        self.lg = LightGlue(features="dedode" + desc_type).to(device).eval()

    def preprocess(self, img):
        """Add a leading batch dimension to inputs with fewer than 4 dims.

        No normalization or padding is done here: the kornia DeDoDe model
        applies ImageNet normalization and pads if not divisible by default.
        """
        return img.unsqueeze(0) if img.ndim < 4 else img

    def _image_size(self, img):
        """Return a (1, 2) integer tensor of ``img``'s (width, height) on ``self.device``."""
        height, width = img.shape[-2:]
        return torch.tensor([[width, height]], device=self.device)

    def _forward(self, img0, img1):
        """Detect, describe and match keypoints between two images.

        Returns:
            ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)``: the matched
            keypoints of each image (taken from batch element 0), followed by
            the raw keypoint and descriptor outputs of DeDoDe for each image.
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        kpts0, _, desc0 = self.model(img0)
        kpts1, _, desc1 = self.model(img1)

        match_input = {
            "image0": {
                "keypoints": kpts0,
                "descriptors": desc0,
                "image_size": self._image_size(img0),
            },
            "image1": {
                "keypoints": kpts1,
                "descriptors": desc1,
                "image_size": self._image_size(img1),
            },
        }
        matches = self.lg(match_input)

        # Index pairs for the first batch element: column 0 indexes image0's
        # keypoints, column 1 indexes image1's.
        matching_idxs = matches["matches"][0]

        # Select batch element 0 explicitly (assumes kpts are (B, N, 2) —
        # TODO confirm). squeeze() would also drop the keypoint axis when
        # exactly one keypoint is found, and would keep the batch axis when
        # B > 1, so the indices would no longer address keypoints.
        mkpts0 = kpts0[0][matching_idxs[:, 0]]
        mkpts1 = kpts1[0][matching_idxs[:, 1]]
        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1