from PIL import Image
import torch
import torch.nn.functional as F
import cv2
import numpy as np
from matching import BaseMatcher
from matching.utils import add_to_path
from matching import WEIGHTS_DIR, THIRD_PARTY_DIR
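# The vendored third-party repos must be on sys.path before the ALIKED and VGGT imports below will resolve.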
add_to_path(THIRD_PARTY_DIR.joinpath("ALIKED"))
add_to_path(THIRD_PARTY_DIR.joinpath("vggt"))
from nets.aliked import ALIKED
from vggt.models.vggt import VGGT


def torch_to_cv2(tensor):
    """Convert a CxHxW float tensor in [0, 1] to an HxWxC uint8 array (BGR for 3-channel input), as expected by OpenCV."""
    tensor = tensor.clone().mul(255).permute(1, 2, 0)
    numpy_img = tensor.byte().cpu().numpy()
    if numpy_img.shape[2] == 3:
        numpy_img = cv2.cvtColor(numpy_img, cv2.COLOR_RGB2BGR)
    return numpy_img


class VGGTMatcher(BaseMatcher):
    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
        # ALIKED provides the query keypoints that VGGT then tracks in the reference image.
        self.query_key_point_finder = ALIKED(
            model_name="aliked-n16rot",
            device=device,
            top_k=-1,
            scores_th=0.8,
            n_limit=max_num_keypoints,
        )
        self.target_size = 518  # model input resolution
        self.patch_size = 14  # input dims are rounded to multiples of this patch size
        self.device = device

    def preprocess(self, img, mode="crop"):
        """
        Preprocess a single image tensor for model input.

        mode="crop": resize the width to target_size and center-crop the height if it exceeds target_size.
        mode="pad": resize the longer side to target_size and pad the shorter side up to target_size.

        Returns: (batched tensor of shape (1, 3, H, W), (original_height, original_width))
        """
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")
        if img.dim() != 3 or img.shape[0] != 3:
            raise ValueError("Image must have shape (3, H, W)")
        if mode not in ["crop", "pad"]:
            raise ValueError("Mode must be either 'crop' or 'pad'")

        _, height, width = img.shape
        orig_shape = (height, width)

        if mode == "pad":
            if width >= height:
                new_width = self.target_size
                new_height = round(height * (new_width / width) / self.patch_size) * self.patch_size
            else:
                new_height = self.target_size
                new_width = round(width * (new_height / height) / self.patch_size) * self.patch_size
        else:  # mode == "crop"
            new_width = self.target_size
            new_height = round(height * (new_width / width) / self.patch_size) * self.patch_size

        img = F.interpolate(
            img.unsqueeze(0), size=(new_height, new_width), mode="bicubic", align_corners=False
        ).squeeze(0)

        if mode == "crop" and new_height > self.target_size:
            # Center-crop the height down to target_size.
            start_y = (new_height - self.target_size) // 2
            img = img[:, start_y : start_y + self.target_size, :]

        if mode == "pad":
            h_padding = self.target_size - img.shape[1]
            w_padding = self.target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                # Pad with white (1.0) to reach a square target_size x target_size input.
                img = F.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )

        return img.unsqueeze(0), orig_shape
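
    # Worked example (illustrative; the 480x640 input below is hypothetical, not from the repo):
    #   for img.shape == (3, 480, 640),
    #   mode="crop": new_width = 518, new_height = round(480 * 518 / 640 / 14) * 14 = 392;
    #                392 <= 518, so no crop -> output shape (1, 3, 392, 518).
    #   mode="pad":  same resize to 392x518, then 126 px of white padding split top/bottom
    #                -> output shape (1, 3, 518, 518).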

    def _forward(self, img0, img1):
        # Preprocess both images to the model input size.
        query_image_tensor, img0_orig_shape = self.preprocess(img0)
        reference_image_tensor, img1_orig_shape = self.preprocess(img1)

        # Convert the query image to OpenCV format and detect keypoints with ALIKED.
        query_image_cv2 = torch_to_cv2(query_image_tensor.squeeze(0))
        aliked_pred = self.query_key_point_finder.run(query_image_cv2)
        mkpts0 = torch.tensor(aliked_pred['keypoints'], dtype=torch.float32, device=self.device)

        # Model input sizes after preprocessing.
        H0, W0 = query_image_tensor.shape[-2:]
        H1, W1 = reference_image_tensor.shape[-2:]

        # Rescale mkpts0 from the query input size (H0, W0) to the reference input size (H1, W1).
        mkpts0_for_model = torch.tensor(
            self.rescale_coords(mkpts0, H1, W1, H0, W0),
            dtype=torch.float32,
            device=self.device,
        )

        # Forward pass through VGGT, tracking the rescaled query points in the reference image.
        vggt_pred = self.model(reference_image_tensor, query_points=mkpts0_for_model)
        mkpts1 = vggt_pred['track'].squeeze()

        # Rescale both sets of points from the model input sizes back to the original image sizes.
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)

        return mkpts0, mkpts1, None, None, None, None
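

if __name__ == "__main__":
    # Minimal smoke-test sketch (illustrative, not part of the matcher API).
    # It only exercises the weight-free helper above: constructing a VGGTMatcher
    # would download the facebook/VGGT-1B and ALIKED checkpoints, and in the
    # surrounding library the matcher is normally driven through BaseMatcher
    # rather than called directly.
    dummy = torch.rand(3, 480, 640)  # hypothetical RGB image, CxHxW in [0, 1]
    bgr = torch_to_cv2(dummy)
    assert bgr.shape == (480, 640, 3) and bgr.dtype == np.uint8
    print("torch_to_cv2 output:", bgr.shape, bgr.dtype)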