import torch
import torch.nn.functional as F
import cv2
from matching import BaseMatcher
from matching.utils import add_to_path
from matching import THIRD_PARTY_DIR

add_to_path(THIRD_PARTY_DIR.joinpath("ALIKED"))
add_to_path(THIRD_PARTY_DIR.joinpath("vggt"))
from nets.aliked import ALIKED
from vggt.models.vggt import VGGT

def torch_to_cv2(tensor):
    """Convert CxHxW [0,1] tensor to OpenCV-style output"""
    tensor = tensor.clone().mul(255).permute(1, 2, 0)
    numpy_img = tensor.byte().cpu().numpy()
    if numpy_img.shape[2] == 3:
        numpy_img = cv2.cvtColor(numpy_img, cv2.COLOR_RGB2BGR)
    return numpy_img

class VGGTMatcher(BaseMatcher):
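    """
    Two-stage matcher: ALIKED detects keypoints on the preprocessed query image,
    then VGGT's track head is queried with those points (rescaled into the
    reference image's model-input coordinates) to localize the corresponding
    points in the reference image.
    """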
    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        super().__init__(device, **kwargs)
        self.model = VGGT.from_pretrained("facebook/VGGT-1B").to(device)
        self.query_key_point_finder = ALIKED(
            model_name="aliked-n16rot",
            device=device,
            top_k=-1,
            scores_th=0.8,
            n_limit=max_num_keypoints
        )
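        # VGGT's backbone operates on 518-px inputs tokenized into 14-px patches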
        self.target_size = 518
        self.patch_size = 14
        self.device = device

    def preprocess(self, img, mode="crop"):
        """
        Preprocess a single image tensor for model input.
        Returns: (batched tensor of shape (1, 3, H, W), (original_height, original_width))
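        Example (mode="crop"): a (3, 480, 640) input is resized to width 518 and
        height round(480 * 518 / 640 / 14) * 14 = 392, giving a (1, 3, 392, 518) output.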
        """
        if not isinstance(img, torch.Tensor):
            raise TypeError("Input must be a torch.Tensor")
        if img.dim() != 3 or img.shape[0] != 3:
            raise ValueError("Image must have shape (3, H, W)")
        if mode not in ["crop", "pad"]:
            raise ValueError("Mode must be either 'crop' or 'pad'")

        _, height, width = img.shape
        orig_shape = (height, width)

        if mode == "pad":
            if width >= height:
                new_width = self.target_size
                new_height = round(height * (new_width / width) / self.patch_size) * self.patch_size
            else:
                new_height = self.target_size
                new_width = round(width * (new_height / height) / self.patch_size) * self.patch_size
        else:  # mode == "crop"
            new_width = self.target_size
            new_height = round(height * (new_width / width) / self.patch_size) * self.patch_size

        img = F.interpolate(
            img.unsqueeze(0), size=(new_height, new_width), mode="bicubic", align_corners=False
        ).squeeze(0)

        if mode == "crop" and new_height > self.target_size:
            start_y = (new_height - self.target_size) // 2
            img = img[:, start_y : start_y + self.target_size, :]

        if mode == "pad":
            h_padding = self.target_size - img.shape[1]
            w_padding = self.target_size - img.shape[2]
            if h_padding > 0 or w_padding > 0:
                pad_top = h_padding // 2
                pad_bottom = h_padding - pad_top
                pad_left = w_padding // 2
                pad_right = w_padding - pad_left
                img = F.pad(
                    img,
                    (pad_left, pad_right, pad_top, pad_bottom),
                    mode="constant",
                    value=1.0,
                )

        return img.unsqueeze(0), orig_shape

    def _forward(self, img0, img1):
        # Preprocess both images to model input size
        query_image_tensor, img0_orig_shape = self.preprocess(img0)
        reference_image_tensor, img1_orig_shape = self.preprocess(img1)

        # Convert the query image to OpenCV format for ALIKED
        query_image_cv2 = torch_to_cv2(query_image_tensor.squeeze(0))
        # Run ALIKED on the preprocessed query image
        pred = self.query_key_point_finder.run(query_image_cv2)
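        # ALIKED's run() returns a dict; 'keypoints' holds (N, 2) (x, y) pixel coordinates in the preprocessed query image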
        mkpts0 = torch.tensor(pred['keypoints'], dtype=torch.float32, device=self.device)

        # Get the model input sizes
        H0, W0 = query_image_tensor.shape[-2:]
        H1, W1 = reference_image_tensor.shape[-2:]

        # Rescale mkpts0 from ALIKED image size (query_image_tensor) to reference image size (reference_image_tensor)
        mkpts0_for_model = torch.tensor(
            self.rescale_coords(mkpts0, H1, W1, H0, W0),
            dtype=torch.float32,
            device=self.device
        )

        # Forward pass to VGGT with the rescaled query points
        pred = self.model(reference_image_tensor, query_points=mkpts0_for_model)

        # pred['track'] has shape (B, S, N, 2); with a single frame, squeeze() drops the batch and frame dims, leaving (N, 2)
        mkpts1 = pred['track'].squeeze()

        # Rescale mkpts1 from reference image size (VGGT input) to original candidate image size
        mkpts1 = self.rescale_coords(mkpts1, *img1_orig_shape, H1, W1)
        # Rescale mkpts0 from query image size (VGGT input) to original query image size
        mkpts0 = self.rescale_coords(mkpts0, *img0_orig_shape, H0, W0)

        return mkpts0, mkpts1, None, None, None, None
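

if __name__ == "__main__":
    # Illustrative smoke test only (not part of the matcher). The asset paths and
    # the load_image() helper are assumptions; if BaseMatcher does not provide
    # load_image, substitute any loader that returns a (3, H, W) float tensor in [0, 1].
    device = "cuda" if torch.cuda.is_available() else "cpu"
    matcher = VGGTMatcher(device=device, max_num_keypoints=1024)

    img0 = matcher.load_image("assets/query.jpg")      # hypothetical path
    img1 = matcher.load_image("assets/reference.jpg")  # hypothetical path

    mkpts0, mkpts1, *_ = matcher._forward(img0, img1)
    print(f"Tracked {mkpts0.shape[0]} query keypoints into the reference image")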