|
import py3_wget |
|
from kornia.color import rgb_to_grayscale |
|
import shutil |
|
|
|
from matching import BaseMatcher, THIRD_PARTY_DIR |
|
from matching.utils import add_to_path |
|
|
|
# Root of the SiLK checkout inside the project's third-party directory.
BASE_PATH = THIRD_PARTY_DIR.joinpath("silk")
# Make the checkout importable so the `scripts.*` / `silk.*` imports in this
# module resolve (presumably add_to_path prepends to sys.path — confirm in matching.utils).
add_to_path(BASE_PATH)
|
|
|
|
|
def setup_silk():
    """Create the ``silk`` package directory inside the SiLK checkout.

    The ``silk.*`` imports in this module need ``BASE_PATH/silk`` to exist,
    while the checkout keeps those sources under ``BASE_PATH/lib``. If
    ``silk`` is missing, ``lib`` is copied to ``silk``. If ``silk`` already
    exists, nothing is done, so repeated calls are harmless.

    Returns:
        None.

    Raises:
        FileNotFoundError: if ``silk`` is missing and ``lib`` is not a
            directory that can be copied.
    """
    silk_dir = BASE_PATH.joinpath("silk")
    if silk_dir.exists():
        return None

    lib_dir = BASE_PATH.joinpath("lib")
    # Explicit check rather than `assert`: asserts are stripped under `python -O`,
    # which would let copytree fail later with a less helpful error.
    # (is_dir() is False for non-existent paths, so it also covers "missing".)
    if not lib_dir.is_dir():
        raise FileNotFoundError(
            f"Cannot set up SiLK: source directory {lib_dir} does not exist "
            f"(it is copied to {silk_dir})"
        )

    shutil.copytree(lib_dir, silk_dir)
    return None
|
|
|
|
|
# Import the SiLK entry points used by SilkMatcher. If any of them cannot be
# found, assume the `silk` package directory has not been created yet. In that
# case run setup_silk() (which copies BASE_PATH/lib to BASE_PATH/silk) and retry
# the same imports once; a second failure propagates to the importer.
# NOTE(review): this catches *any* ModuleNotFoundError, including a missing
# third-party dependency of SiLK, in which case setup_silk() runs needlessly
# before the retry re-raises.
try:
    from scripts.examples.common import get_model
    from silk.backbones.silk.silk import from_feature_coords_to_image_coords
    from silk.models.silk import matcher
except ModuleNotFoundError:
    setup_silk()
    from scripts.examples.common import get_model
    from silk.backbones.silk.silk import from_feature_coords_to_image_coords
    from silk.models.silk import matcher
|
|
|
|
|
class SilkMatcher(BaseMatcher):
    """Sparse keypoint detector/descriptor/matcher built on SiLK.

    Each image is converted to grayscale and run through the SiLK model to
    obtain sparse keypoint positions and descriptors. The two descriptor sets
    are then matched with SiLK's ``matcher`` using the configured
    post-processing.
    """

    # Remote folder the checkpoint is fetched from (see ``download_weights``).
    CKPT_DOWNLOAD_SRC = "https://dl.fbaipublicfiles.com/silk/assets/models/silk/"
    # Local folder the checkpoint is stored in, inside the SiLK checkout.
    CKPT_DIR = BASE_PATH.joinpath(r"assets/models/silk")

    # Accepted values for ``matcher_post_processing``.
    MATCHER_POSTPROCESS_OPTIONS = ["ratio-test", "mnn", "double-softmax"]

    def __init__(
        self,
        device="cpu",
        matcher_post_processing="ratio-test",
        matcher_thresh=0.8,
        **kwargs,
    ):
        """Build the SiLK model and descriptor matcher.

        Args:
            device: Passed to ``BaseMatcher`` and to ``get_model``.
            matcher_post_processing: One of ``MATCHER_POSTPROCESS_OPTIONS``;
                passed as ``postprocessing`` to SiLK's ``matcher``.
            matcher_thresh: Passed as ``threshold`` to SiLK's ``matcher``.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.

        Raises:
            ValueError: if ``matcher_post_processing`` is not a supported option.
        """
        # Validate first so a bad option fails before any download or model
        # loading. Use an explicit raise, not `assert`, so it also works under `python -O`.
        if matcher_post_processing not in SilkMatcher.MATCHER_POSTPROCESS_OPTIONS:
            raise ValueError(
                f"Matcher postprocessing must be one of {SilkMatcher.MATCHER_POSTPROCESS_OPTIONS}, "
                f"got {matcher_post_processing!r}"
            )

        super().__init__(device, **kwargs)
        # parents=True: the intermediate "assets/models" folders may not exist yet.
        SilkMatcher.CKPT_DIR.mkdir(parents=True, exist_ok=True)

        self.download_weights()

        # The model is configured to return (sparse_positions, sparse_descriptors).
        self.model = get_model(device=device, default_outputs=("sparse_positions", "sparse_descriptors"))
        self.matcher = matcher(postprocessing=matcher_post_processing, threshold=matcher_thresh)

    def download_weights(self):
        """Download the SiLK checkpoint into ``CKPT_DIR`` unless it is already present."""
        # NOTE(review): presumably this is the checkpoint get_model() loads — confirm.
        ckpt_name = "coco-rgb-aug.ckpt"
        ckpt_path = SilkMatcher.CKPT_DIR.joinpath(ckpt_name)
        if ckpt_path.exists():
            return
        print(f"Downloading {ckpt_name}")
        py3_wget.download_file(SilkMatcher.CKPT_DOWNLOAD_SRC + ckpt_name, ckpt_path)

    def preprocess(self, img):
        """Return ``img`` as a batched grayscale tensor.

        A 3-D input is treated as a single image and gets a leading batch
        dimension. The result is then converted with kornia's ``rgb_to_grayscale``.
        Assumes channel-first RGB input (3 channels at dim -3) — TODO confirm callers.
        """
        if img.ndim == 3:
            img = img.unsqueeze(0)
        return rgb_to_grayscale(img)

    def _forward(self, img0, img1):
        """Detect, describe and match keypoints in a pair of images.

        Only the first batch element (index 0) of each image's model output is used.

        Returns:
            ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)`` as numpy arrays:
            - ``mkpts0``/``mkpts1``: matched keypoints; row i of ``mkpts0``
              matches row i of ``mkpts1``.
            - ``kpts0``/``kpts1``: all detected keypoints in each image.
            - ``desc0``/``desc1``: the corresponding descriptors.
            Keypoints keep only the first two position columns, in swapped
            order. Presumably SiLK emits (row, col), so this yields (x, y) — TODO confirm.
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        sparse_positions_0, sparse_descriptors_0 = self.model(img0)
        sparse_positions_1, sparse_descriptors_1 = self.model(img1)

        # Map keypoint positions from feature-map coordinates to image coordinates.
        sparse_positions_0 = from_feature_coords_to_image_coords(self.model, sparse_positions_0)
        sparse_positions_1 = from_feature_coords_to_image_coords(self.model, sparse_positions_1)

        # Each row of `matches` is an (index into image 0, index into image 1) pair.
        matches = self.matcher(sparse_descriptors_0[0], sparse_descriptors_1[0])

        mkpts0 = self._positions_to_xy(sparse_positions_0[0][matches[:, 0]])
        mkpts1 = self._positions_to_xy(sparse_positions_1[0][matches[:, 1]])

        kpts0 = self._positions_to_xy(sparse_positions_0[0])
        kpts1 = self._positions_to_xy(sparse_positions_1[0])
        desc0 = self._to_numpy(sparse_descriptors_0[0])
        desc1 = self._to_numpy(sparse_descriptors_1[0])

        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1

    @staticmethod
    def _to_numpy(tensor):
        """Return ``tensor`` as a numpy array, detached from autograd and moved to CPU."""
        return tensor.detach().cpu().numpy()

    @classmethod
    def _positions_to_xy(cls, positions):
        """Convert a positions tensor to numpy, keeping its first two columns in swapped order."""
        # Columns past the first two (presumably a confidence score) are dropped.
        return cls._to_numpy(positions)[:, :2][:, [1, 0]]
|
|