# Pawel Piwowarski
# init commit
# 0a82b18
import py3_wget
from kornia.color import rgb_to_grayscale
import shutil
from matching import BaseMatcher, THIRD_PARTY_DIR
from matching.utils import add_to_path
# Root of the silk third-party checkout (THIRD_PARTY_DIR is provided by the `matching` package).
BASE_PATH = THIRD_PARTY_DIR.joinpath("silk")
# Make the checkout importable. NOTE(review): presumably add_to_path puts BASE_PATH on
# sys.path so that the `scripts.*` / `silk.*` imports below resolve — confirm in matching.utils.
add_to_path(BASE_PATH)
def setup_silk():
    """Make the ``silk`` package importable by copying ``BASE_PATH/lib`` to ``BASE_PATH/silk``.

    silk is meant to be installed with a symlink from lib/ to silk/, but this often
    does not work (see https://github.com/facebookresearch/silk/issues/32). This is
    the "ugly but works" workaround: physically copy the directory instead.
    Does nothing if ``BASE_PATH/silk`` already exists.

    Returns:
        None.

    Raises:
        FileNotFoundError: if ``BASE_PATH/lib`` does not exist or is not a directory.
    """
    import importlib

    silk_dir = BASE_PATH.joinpath("silk")
    if silk_dir.exists():
        return None

    lib_dir = BASE_PATH.joinpath("lib")
    # An explicit raise (rather than assert) so the check survives `python -O`.
    if not lib_dir.is_dir():
        raise FileNotFoundError(f"Cannot set up silk: source directory {lib_dir} not found")

    shutil.copytree(lib_dir, silk_dir)
    # The package was created after the interpreter started; drop the import
    # system's cached directory listings so a subsequent `import silk` finds it.
    importlib.invalidate_caches()
    return None
# Import silk's entry points. If the first attempt fails with ModuleNotFoundError
# (typically the missing lib/ -> silk/ link that setup_silk works around), create the
# `silk` package via setup_silk() and retry once; a second failure propagates.
try:
    from scripts.examples.common import get_model
    from silk.backbones.silk.silk import from_feature_coords_to_image_coords
    from silk.models.silk import matcher
except ModuleNotFoundError:
    # NOTE(review): this catches *any* missing module, not only `silk`; an unrelated
    # missing dependency still runs setup_silk() before the retry below re-raises it.
    setup_silk()
    from scripts.examples.common import get_model
    from silk.backbones.silk.silk import from_feature_coords_to_image_coords
    from silk.models.silk import matcher
class SilkMatcher(BaseMatcher):
    """Sparse keypoint detector/descriptor/matcher built on facebookresearch's SiLK.

    reference: https://github.com/facebookresearch/silk/blob/main/scripts/examples/silk-inference.py
    """

    # Base URL; the checkpoint file name is appended to it in ``download_weights``.
    CKPT_DOWNLOAD_SRC = "https://dl.fbaipublicfiles.com/silk/assets/models/silk/"
    # Local directory (inside the silk third-party checkout) the checkpoint is stored in.
    CKPT_DIR = BASE_PATH.joinpath(r"assets/models/silk")
    # Values accepted for the ``matcher_post_processing`` constructor argument.
    MATCHER_POSTPROCESS_OPTIONS = ["ratio-test", "mnn", "double-softmax"]

    def __init__(
        self,
        device="cpu",
        matcher_post_processing="ratio-test",
        matcher_thresh=0.8,
        **kwargs,
    ):
        """Download the checkpoint if needed, then build the SiLK model and matcher.

        Args:
            device: passed to ``BaseMatcher.__init__`` and to silk's ``get_model``.
            matcher_post_processing: one of ``MATCHER_POSTPROCESS_OPTIONS``; forwarded
                to silk's ``matcher`` as ``postprocessing``.
            matcher_thresh: forwarded to silk's ``matcher`` as ``threshold``.
            **kwargs: forwarded to ``BaseMatcher.__init__``.

        Raises:
            ValueError: if ``matcher_post_processing`` is not a supported option.
        """
        super().__init__(device, **kwargs)
        # Validate cheap arguments before any download / model-loading work; an
        # explicit raise (not assert) so the check survives `python -O`.
        if matcher_post_processing not in SilkMatcher.MATCHER_POSTPROCESS_OPTIONS:
            raise ValueError(
                f"Matcher postprocessing must be one of {SilkMatcher.MATCHER_POSTPROCESS_OPTIONS}"
            )

        # parents=True: intermediate assets/models/ dirs may not exist in a fresh checkout.
        SilkMatcher.CKPT_DIR.mkdir(parents=True, exist_ok=True)
        self.download_weights()

        # NOTE(review): get_model() is called without a checkpoint path, so this relies on
        # silk's default checkpoint location matching the file download_weights() fetches.
        self.model = get_model(device=device, default_outputs=("sparse_positions", "sparse_descriptors"))
        self.matcher = matcher(postprocessing=matcher_post_processing, threshold=matcher_thresh)

    def download_weights(self):
        """Download ``coco-rgb-aug.ckpt`` into ``CKPT_DIR`` unless it is already there."""
        ckpt_name = "coco-rgb-aug.ckpt"
        ckpt_path = SilkMatcher.CKPT_DIR.joinpath(ckpt_name)
        if not ckpt_path.exists():
            print(f"Downloading {ckpt_name}")
            py3_wget.download_file(SilkMatcher.CKPT_DOWNLOAD_SRC + ckpt_name, ckpt_path)

    def preprocess(self, img):
        """Convert an RGB image tensor into a batched grayscale tensor.

        Expects a float image (values in 0-1) with a channel dimension. A 3-D input
        gets a leading batch dimension added before the grayscale conversion.
        """
        if img.ndim == 3:
            img = img.unsqueeze(0)
        return rgb_to_grayscale(img)

    @staticmethod
    def _to_numpy(tensor):
        """Detach ``tensor`` from autograd and return it as a CPU numpy array."""
        return tensor.detach().cpu().numpy()

    @staticmethod
    def _positions_to_xy(positions):
        """Return the first two columns of ``positions`` as numpy, swapped to (x, y) order.

        The positions' first two columns are treated as (row, col), so they are
        reordered to (col, row) == (x, y); any further columns are dropped.
        """
        return SilkMatcher._to_numpy(positions[:, :2])[:, [1, 0]]

    def _forward(self, img0, img1):
        """Detect, describe and match keypoints between two images.

        Only the first image of each (batched) input is matched.

        Returns:
            Tuple ``(mkpts0, mkpts1, kpts0, kpts1, desc0, desc1)`` of numpy arrays:
            matched keypoints in each image, all detected keypoints in each image
            (keypoints in x, y order), and the descriptors of all detected keypoints.
        """
        img0 = self.preprocess(img0)
        img1 = self.preprocess(img1)

        # The model was built with default_outputs=(sparse_positions, sparse_descriptors),
        # so each call yields exactly those two outputs.
        sparse_positions_0, sparse_descriptors_0 = self.model(img0)
        sparse_positions_1, sparse_descriptors_1 = self.model(img1)

        # Map positions from feature-map to image coordinates. The columns are handled
        # as (row, col, ...) — hence the swap in _positions_to_xy — not (x, y, conf).
        sparse_positions_0 = from_feature_coords_to_image_coords(self.model, sparse_positions_0)
        sparse_positions_1 = from_feature_coords_to_image_coords(self.model, sparse_positions_1)

        # matches[:, 0] indexes image-0 keypoints, matches[:, 1] image-1 keypoints.
        matches = self.matcher(sparse_descriptors_0[0], sparse_descriptors_1[0])

        mkpts0 = self._positions_to_xy(sparse_positions_0[0][matches[:, 0]])
        mkpts1 = self._positions_to_xy(sparse_positions_1[0][matches[:, 1]])

        kpts0 = self._positions_to_xy(sparse_positions_0[0])
        kpts1 = self._positions_to_xy(sparse_positions_1[0])

        desc0 = self._to_numpy(sparse_descriptors_0[0])
        desc1 = self._to_numpy(sparse_descriptors_1[0])
        return mkpts0, mkpts1, kpts0, kpts1, desc0, desc1