|
from kornia.feature import DeDoDe, LightGlue |
|
import torch |
|
import kornia |
|
|
|
from matching.utils import to_numpy |
|
from matching import get_version, BaseMatcher |
|
|
|
|
|
class DeDoDeLightGlue(BaseMatcher):
    """Matcher that pairs kornia's DeDoDe detector/descriptor with LightGlue.

    DeDoDe produces keypoints and descriptors for each image; LightGlue then
    matches the two keypoint sets. The accepted pretrained weight names are
    listed in ``detector_options`` and ``descriptor_options``.
    """

    # Weight names accepted for ``detector_weights`` / ``desc_weights``.
    detector_options = ["L-upright", "L-C4", "L-SO2", "L-C4-v2"]
    descriptor_options = ["B-upright", "G-upright", "B-C4", "B-SO2", "G-C4"]

    def __init__(
        self,
        device="cpu",
        detector_weights="L-C4-v2",
        desc_weights="B-upright",
        **kwargs,
    ):
        """Load the pretrained DeDoDe model and the matching LightGlue head.

        Args:
            device: Device the models are moved to; a string such as
                ``"cpu"``/``"cuda"`` or a ``torch.device``.
            detector_weights: One of ``detector_options``.
            desc_weights: One of ``descriptor_options``.
            **kwargs: Forwarded unchanged to ``BaseMatcher.__init__``.

        Raises:
            AssertionError: If the installed kornia is older than 0.7.3 or
                if either weight name is not a known option.
        """
        super().__init__(device, **kwargs)

        # Compare the version as a tuple so that newer minor/major releases
        # (e.g. 0.8.0 or 1.0.0) are accepted; checking minor and patch
        # independently would wrongly reject them.
        major, minor, patch = get_version(kornia)
        assert (major, minor, patch) >= (
            0,
            7,
            3,
        ), "DeDoDe-LG only available in kornia v 0.7.3 or greater. Update kornia to use this model."

        assert (
            detector_weights in DeDoDeLightGlue.detector_options
        ), f"Invalid detector weights passed ({detector_weights}). Choose from {DeDoDeLightGlue.detector_options}"
        assert (
            desc_weights in DeDoDeLightGlue.descriptor_options
        ), f"Invalid descriptor weights passed ({desc_weights}). Choose from {DeDoDeLightGlue.descriptor_options}"

        # First letter of the descriptor name ("B" or "G") selects the
        # LightGlue feature set: "dedodeb" or "dedodeg".
        desc_type = desc_weights[0].lower()

        # str() lets the check work for torch.device objects as well as
        # plain strings. Half precision is only requested on CUDA.
        on_cuda = "cuda" in str(device)
        self.model = DeDoDe.from_pretrained(
            detector_weights=detector_weights,
            descriptor_weights=desc_weights,
            amp_dtype=torch.float16 if on_cuda else torch.float32,
        ).to(device)
        self.lg = LightGlue(features="dedode" + desc_type).to(device).eval()

    def preprocess(self, img):
        """Return ``img`` with a leading batch axis added if it has fewer than 4 dims."""
        return img.unsqueeze(0) if img.ndim < 4 else img

    def _lightglue_input(self, kpts, desc, img):
        """Build the per-image dict passed to LightGlue.

        ``image_size`` is the last two entries of ``img.shape`` in reversed
        order, reshaped to (1, 2) and placed on ``self.device``.
        """
        size = torch.tensor(img.shape[-2:][::-1]).view(1, 2).to(self.device)
        return {"keypoints": kpts, "descriptors": desc, "image_size": size}

    def _forward(self, img0, img1):
        """Detect, describe and match keypoints between two images.

        Args:
            img0: First image tensor; batched via ``preprocess`` if needed.
            img1: Second image tensor; batched via ``preprocess`` if needed.

        Returns:
            Tuple ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)``:
            the matched keypoints of each image followed by the raw DeDoDe
            keypoints and descriptors (returned exactly as the model
            produced them).
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        # The per-keypoint scores are not used by the matcher.
        kpts0, _scores0, desc0 = self.model(img0)
        kpts1, _scores1, desc1 = self.model(img1)

        matches = self.lg(
            {
                "image0": self._lightglue_input(kpts0, desc0, img0),
                "image1": self._lightglue_input(kpts1, desc1, img1),
            }
        )

        # Index pairs for the first batch element; column 0 indexes kpts0,
        # column 1 indexes kpts1.
        matching_idxs = matches["matches"][0]

        # squeeze(0) removes only the leading axis (assumed to be a size-1
        # batch axis — confirm against kornia's DeDoDe output). A bare
        # squeeze() would also drop the keypoint axis when exactly one
        # keypoint is detected, corrupting the indexing below.
        mkpts0 = kpts0.squeeze(0)[matching_idxs[:, 0]]
        mkpts1 = kpts1.squeeze(0)[matching_idxs[:, 1]]

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1
|
|