import cv2
import torch
import torch.nn.functional as F

from matching import BaseMatcher, THIRD_PARTY_DIR
from matching.utils import add_to_path

add_to_path(THIRD_PARTY_DIR.joinpath("ALIKED"))
add_to_path(THIRD_PARTY_DIR.joinpath("vggt"))

from nets.aliked import ALIKED
from vggt.models.vggt import VGGT


def torch_to_cv2(tensor):
    """Convert a CxHxW tensor with values in [0, 1] into an OpenCV-style
    uint8 HxWxC image (BGR channel order for 3-channel inputs)."""
    tensor = tensor.clone().mul(255).permute(1, 2, 0)
    numpy_img = tensor.byte().cpu().numpy()
    if numpy_img.shape[2] == 3:
        numpy_img = cv2.cvtColor(numpy_img, cv2.COLOR_RGB2BGR)
    return numpy_img


class VGGTMatcher(BaseMatcher):
    """Match two images by detecting ALIKED keypoints in the query image and
    localizing them in the reference image with VGGT's tracking head."""

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
        # ALIKED supplies the query keypoints that VGGT will track.
        self.query_key_point_finder = ALIKED(
            model_name="aliked-n16rot",
            device=device,
            top_k=-1,
            scores_th=0.8,
            n_limit=max_num_keypoints,
        )
        # VGGT operates on 518-pixel inputs built from 14-pixel ViT patches.
        self.target_size = 518
        self.patch_size = 14
        self.device = device

    def preprocess(self, img, mode="crop"):
        """
        Preprocess a single image tensor for model input.

        Resizes so the width (in "crop" mode) or the longer side (in "pad"
        mode) equals `self.target_size`, with the other side rounded to a
        multiple of `self.patch_size`, then center-crops or pads to size.

        Returns: (batched tensor of shape (1, 3, H, W), (original_height, original_width))
        """
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")
        if img.dim() != 3 or img.shape[0] != 3:
            raise ValueError("Image must have shape (3, H, W)")
        if mode not in ["crop", "pad"]:
            raise ValueError("Mode must be either 'crop' or 'pad'")

        _, height, width = img.shape
        orig_shape = (height, width)

        if mode == "pad":
            # Scale the longer side to target_size; round the shorter side
            # to a multiple of patch_size.
            if width >= height:
                new_width = self.target_size
                new_height = round(height * (new_width / width) / self.patch_size) * self.patch_size
            else:
                new_height = self.target_size
                new_width = round(width * (new_height / height) / self.patch_size) * self.patch_size
        else:
            # Crop mode: fix the width to target_size and round the height.
            new_width = self.target_size
            new_height = round(height * (new_width / width) / self.patch_size) * self.patch_size

        img = F.interpolate(
            img.unsqueeze(0), size=(new_height, new_width), mode="bicubic", align_corners=False
        ).squeeze(0)

        if mode == "crop" and new_height > self.target_size:
            # Center-crop tall images down to target_size in height.
            start_y = (new_height - self.target_size) // 2
            img = img[:, start_y : start_y + self.target_size, :]

        if mode == "pad":
            # Pad with white (value 1.0) to a square target_size input.
            h_padding = self.target_size - img.shape[1]
            w_padding = self.target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = F.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )

        return img.unsqueeze(0), orig_shape
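    # Worked example of the arithmetic above (a sketch, not executed code): a
    # 3x480x640 landscape input in the default "crop" mode is resized to
    # new_width = 518 and new_height = round(480 * (518 / 640) / 14) * 14
    # = 28 * 14 = 392; since 392 <= 518, no center crop is applied. In "pad"
    # mode the same image would additionally be padded from 392 to 518 rows.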

    def _forward(self, img0, img1):
        query_image_tensor, img0_orig_shape = self.preprocess(img0)
        reference_image_tensor, img1_orig_shape = self.preprocess(img1)

        # Detect keypoints in the query image with ALIKED; `run` returns
        # pixel coordinates as a numpy array.
        query_image_cv2 = torch_to_cv2(query_image_tensor.squeeze(0))
        aliked_pred = self.query_key_point_finder.run(query_image_cv2)
        mkpts0 = torch.tensor(aliked_pred["keypoints"], dtype=torch.float32, device=self.device)

        # Processed sizes of the two (1, 3, H, W) tensors.
        H0, W0 = query_image_tensor.shape[-2:]
        H1, W1 = reference_image_tensor.shape[-2:]

        # Map the query keypoints into the reference frame's resolution; they
        # serve as frame-0 coordinates for VGGT's track head.
        mkpts0_for_model = torch.as_tensor(
            self.rescale_coords(mkpts0, H1, W1, H0, W0),
            dtype=torch.float32,
            device=self.device,
        )

        # VGGT tracks frame-0 query points across a sequence of equally sized
        # frames, so resize the query frame to the reference frame's size and
        # pass both images as a two-frame sequence.
        if (H0, W0) != (H1, W1):
            query_image_tensor = F.interpolate(
                query_image_tensor, size=(H1, W1), mode="bicubic", align_corners=False
            )
        pred = self.model(
            torch.cat([query_image_tensor, reference_image_tensor], dim=0),
            query_points=mkpts0_for_model,
        )
        # pred["track"] has shape (B, S, N, 2); the last frame holds the
        # tracked positions in the reference image.
        mkpts1 = pred["track"][0, -1]

        # Map both point sets back to the original image resolutions.
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)

        return mkpts0, mkpts1, None, None, None, None
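

if __name__ == "__main__":
    # Minimal usage sketch, assuming the BaseMatcher interface used elsewhere
    # in this repo (a load_image helper and a callable matcher that returns a
    # result dict with a "num_inliers" entry). The image paths below are
    # placeholders, not files shipped with the repository.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    matcher = VGGTMatcher(device=device)
    img0 = matcher.load_image("path/to/query.jpg")
    img1 = matcher.load_image("path/to/reference.jpg")
    result = matcher(img0, img1)
    print(f"num_inliers: {result['num_inliers']}")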