|
import torch |
|
import os |
|
import torchvision.transforms as tfm |
|
import py3_wget |
|
from matching import BaseMatcher, THIRD_PARTY_DIR, WEIGHTS_DIR |
|
from matching.utils import resize_to_divisible, add_to_path |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
add_to_path(THIRD_PARTY_DIR.joinpath("affine-steerers")) |
|
from affine_steerers.utils import build_affine |
|
from affine_steerers.matchers.dual_softmax_matcher import DualSoftMaxMatcher, MaxSimilarityMatcher |
|
from affine_steerers import dedode_detector_L, dedode_descriptor_B, dedode_descriptor_G |
|
|
|
class AffSteererMatcher(BaseMatcher): |
|
detector_path_L = WEIGHTS_DIR.joinpath("dedode_detector_C4_affsteerers.pth") |
|
|
|
descriptor_path_equi_G = WEIGHTS_DIR.joinpath("descriptor_aff_equi_G.pth") |
|
descriptor_path_steer_G = WEIGHTS_DIR.joinpath("descriptor_aff_steer_G.pth") |
|
|
|
descriptor_path_equi_B = WEIGHTS_DIR.joinpath("descriptor_aff_equi_B.pth") |
|
descriptor_path_steer_B = WEIGHTS_DIR.joinpath("descriptor_aff_steer_B.pth") |
|
|
|
steerer_path_equi_G = WEIGHTS_DIR.joinpath("steerer_aff_equi_G.pth") |
|
steerer_path_steer_G = WEIGHTS_DIR.joinpath("steerer_aff_steer_G.pth") |
|
|
|
steerer_path_equi_B = WEIGHTS_DIR.joinpath("steerer_aff_equi_B.pth") |
|
steerer_path_steer_B = WEIGHTS_DIR.joinpath("steerer_aff_steer_B.pth") |
|
|
|
dino_patch_size = 14 |
|
|
|
STEERER_TYPES = ["equi_G", "steer_G", "equi_B", "steer_B"] |
|
|
|
def __init__( |
|
self, |
|
device="cpu", |
|
max_num_keypoints=10_000, |
|
steerer_type="equi_G", |
|
match_threshold=0.01, |
|
*args, |
|
**kwargs, |
|
): |
|
super().__init__(device, **kwargs) |
|
|
|
if self.device != "cuda": |
|
raise ValueError("Only device 'cuda' supported for AffineSteerers.") |
|
|
|
WEIGHTS_DIR.mkdir(exist_ok=True) |
|
|
|
self.steerer_type = steerer_type |
|
if self.steerer_type not in self.STEERER_TYPES: |
|
raise ValueError(f'unsupported type for aff-steerer: {steerer_type}. Must choose from {self.STEERER_TYPES}.') |
|
|
|
|
|
|
|
self.download_weights() |
|
|
|
self.max_keypoints = max_num_keypoints |
|
self.threshold = match_threshold |
|
|
|
self.normalize = tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]) |
|
|
|
|
|
self.detector, self.descriptor, self.steerer, self.matcher = self.build_matcher() |
|
|
|
def download_weights(self): |
|
if not AffSteererMatcher.detector_path_L.exists(): |
|
print("Downloading dedode_detector_C4.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/dedode_detector_C4.pth", |
|
AffSteererMatcher.detector_path_L, |
|
) |
|
|
|
|
|
if self.steerer_type == "equi_G" and not AffSteererMatcher.descriptor_path_equi_G.exists(): |
|
print("Downloading descriptor_aff_equi_G.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/descriptor_aff_equi_G.pth", |
|
AffSteererMatcher.descriptor_path_equi_G, |
|
) |
|
|
|
if self.steerer_type == "steer_G" and not AffSteererMatcher.descriptor_path_steer_G.exists(): |
|
print("Downloading descriptor_aff_steer_G.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/descriptor_aff_steer_G.pth", |
|
AffSteererMatcher.descriptor_path_steer_G, |
|
) |
|
|
|
if self.steerer_type == "equi_B" and not AffSteererMatcher.descriptor_path_equi_B.exists(): |
|
print("Downloading descriptor_aff_equi_B.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/descriptor_aff_equi_B.pth", |
|
AffSteererMatcher.descriptor_path_equi_B, |
|
) |
|
|
|
if self.steerer_type == "steer_B" and not AffSteererMatcher.descriptor_path_steer_B.exists(): |
|
print("Downloading descriptor_aff_steer_B.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/descriptor_aff_steer_B.pth", |
|
AffSteererMatcher.descriptor_path_steer_B, |
|
) |
|
|
|
|
|
if self.steerer_type == "equi_G" and not AffSteererMatcher.steerer_path_equi_G.exists(): |
|
print("Downloading steerer_aff_equi_G.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/steerer_aff_equi_G.pth", |
|
AffSteererMatcher.steerer_path_equi_G, |
|
) |
|
if self.steerer_type == "steer_G" and not AffSteererMatcher.steerer_path_steer_G.exists(): |
|
print("Downloading steerer_aff_steer_G.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/steerer_aff_steer_G.pth", |
|
AffSteererMatcher.steerer_path_steer_G, |
|
) |
|
|
|
if self.steerer_type == "equi_B" and not AffSteererMatcher.steerer_path_equi_B.exists(): |
|
print("Downloading steerer_aff_equi_B.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/steerer_aff_equi_B.pth", |
|
AffSteererMatcher.steerer_path_equi_B, |
|
) |
|
if self.steerer_type == "steer_B" and not AffSteererMatcher.steerer_path_steer_B.exists(): |
|
print("Downloading steerer_aff_steer_B.pth") |
|
py3_wget.download_file( |
|
"https://github.com/georg-bn/affine-steerers/releases/download/weights/steerer_aff_steer_B.pth", |
|
AffSteererMatcher.steerer_path_steer_B, |
|
) |
|
|
|
def build_matcher(self): |
|
detector = dedode_detector_L(weights=torch.load(self.detector_path_L, map_location=self.device)) |
|
|
|
if "G" in self.steerer_type: |
|
descriptor_path = self.descriptor_path_equi_G if 'equi' in self.steerer_type else self.descriptor_path_steer_G |
|
descriptor = dedode_descriptor_G( |
|
weights=torch.load(descriptor_path, map_location=self.device) |
|
) |
|
else: |
|
descriptor_path = self.descriptor_path_equi_B if 'equi' in self.steerer_type else self.descriptor_path_steer_B |
|
descriptor = dedode_descriptor_B( |
|
weights=torch.load(self.descriptor_path, map_location=self.device) |
|
) |
|
|
|
if "G" in self.steerer_type: |
|
steerer_path = self.steerer_path_equi_G if 'equi' in self.steerer_type else self.steerer_path_steer_G |
|
else: |
|
steerer_path = self.steerer_path_equi_B if 'equi' in self.steerer_type else self.steerer_path_steer_B |
|
|
|
assert steerer_path.exists(), f"could not find steerer weights at {steerer_path}. Please check that they exist." |
|
steerer = self.load_steerer( |
|
steerer_path |
|
).to(self.device).eval() |
|
|
|
steerer.use_prototype_affines = True |
|
|
|
if 'steer' not in self.steerer_type: |
|
steerer.prototype_affines = torch.stack( |
|
[ |
|
build_affine( |
|
angle_1=0., |
|
dilation_1=1., |
|
dilation_2=1., |
|
angle_2=r * 2 * torch.pi / 8 |
|
) |
|
for r in range(8) |
|
], |
|
dim=0, |
|
).to(self.device) |
|
|
|
matcher = MaxSimilarityMatcher( |
|
steerer=steerer, |
|
normalize=False, |
|
inv_temp=5, |
|
threshold=self.threshold |
|
) |
|
if self.device == "cpu": |
|
detector = detector.to(torch.float32) |
|
descriptor = descriptor.to(torch.float32) |
|
steerer = steerer.to(torch.float32) |
|
return detector, descriptor, steerer, matcher |
|
|
|
@staticmethod |
|
def load_steerer(steerer_path, checkpoint=False, prototypes=True, feat_dim=256): |
|
from affine_steerers.steerers import SteererSpread |
|
if checkpoint: |
|
sd = torch.load(steerer_path, map_location="cpu")["steerer"] |
|
else: |
|
sd = torch.load(steerer_path, map_location="cpu") |
|
|
|
nbr_prototypes = 0 |
|
if prototypes and "prototype_affines" in sd: |
|
nbr_prototypes = sd["prototype_affines"].shape[0] |
|
|
|
steerer = SteererSpread( |
|
feat_dim=feat_dim, |
|
max_order=4, |
|
normalize=True, |
|
normalize_only_higher=False, |
|
fix_order_1_scalings=False, |
|
max_determinant_scaling=None, |
|
block_diag_rot=False, |
|
block_diag_optimal_scalings=False, |
|
learnable_determinant_scaling=True, |
|
learnable_basis=True, |
|
learnable_reference_direction=False, |
|
use_prototype_affines=prototypes and "prototype_affines" in sd, |
|
prototype_affines_init=[ |
|
torch.eye(2) |
|
for i in range(nbr_prototypes) |
|
] |
|
) |
|
steerer.load_state_dict(sd) |
|
return steerer |
|
|
|
def preprocess(self, img): |
|
|
|
_, h, w = img.shape |
|
orig_shape = h, w |
|
img = resize_to_divisible(img, self.dino_patch_size) |
|
|
|
img = self.normalize(img).unsqueeze(0).to(self.device) |
|
return img, orig_shape |
|
|
|
def _forward(self, img0, img1): |
|
img0, img0_orig_shape = self.preprocess(img0) |
|
img1, img1_orig_shape = self.preprocess(img1) |
|
|
|
batch_0 = {"image": img0} |
|
detections_0 = self.detector.detect(batch_0, num_keypoints=self.max_keypoints) |
|
keypoints_0, P_0 = detections_0["keypoints"], detections_0["confidence"] |
|
|
|
batch_1 = {"image": img1} |
|
detections_1 = self.detector.detect(batch_1, num_keypoints=self.max_keypoints) |
|
keypoints_1, P_1 = detections_1["keypoints"], detections_1["confidence"] |
|
|
|
description_0 = self.descriptor.describe_keypoints(batch_0, keypoints_0)["descriptions"] |
|
description_1 = self.descriptor.describe_keypoints(batch_1, keypoints_1)["descriptions"] |
|
|
|
matches_0, matches_1, _ = self.matcher.match( |
|
keypoints_0, |
|
description_0, |
|
keypoints_1, |
|
description_1, |
|
) |
|
|
|
H0, W0, H1, W1 = *img0.shape[-2:], *img1.shape[-2:] |
|
mkpts0, mkpts1 = self.matcher.to_pixel_coords(matches_0, matches_1, H0, W0, H1, W1) |
|
|
|
|
|
|
|
mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0) |
|
mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1) |
|
|
|
return mkpts0, mkpts1, keypoints_0, keypoints_1, description_0, description_1 |
|
|