|
import torch |
|
import torchvision.transforms as tfm |
|
from kornia.augmentation import PadTo |
|
from kornia.utils import tensor_to_image |
|
import tempfile |
|
from pathlib import Path |
|
|
|
|
|
from matching import BaseMatcher, THIRD_PARTY_DIR |
|
from matching.utils import add_to_path |
|
|
|
add_to_path(THIRD_PARTY_DIR.joinpath("RoMa")) |
|
from romatch import roma_outdoor, tiny_roma_v1_outdoor |
|
|
|
from PIL import Image |
|
from skimage.util import img_as_ubyte |
|
|
|
|
|
class RomaMatcher(BaseMatcher):
    """Matcher that wraps the ``roma_outdoor`` model from ``romatch``.

    Images are written to temporary PNG files and handed to the model by
    path; the sampled matches are converted to pixel coordinates of the
    images that were written.
    """

    dino_patch_size = 14
    coarse_ratio = 560 / 864

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        """Build the RoMa model and store the matching configuration.

        Args:
            device: Device passed to ``BaseMatcher`` and to ``roma_outdoor``.
            max_num_keypoints: Number of matches requested from
                ``roma_model.sample`` in ``_forward``.
            *args: Accepted for call compatibility; not forwarded.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        self.roma_model = roma_outdoor(device=device)
        self.max_keypoints = max_num_keypoints
        # Not used by this class's own methods; kept as a public attribute.
        self.normalize = tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        self.roma_model.train(False)

    def compute_padding(self, img0, img1):
        """Create ``self.pad``, a padder to the largest side of both images.

        Both inputs must be 3-dimensional (the first dimension is ignored,
        the last two are read as height and width).
        """
        _, h0, w0 = img0.shape
        _, h1, w1 = img1.shape
        pad_dim = max(h0, w0, h1, w1)

        self.pad = PadTo((pad_dim, pad_dim), keepdim=True)

    def preprocess(self, img: torch.Tensor, pad=False) -> tuple:
        """Write ``img`` to a temporary PNG file.

        Args:
            img: Image tensor. Float tensors are clamped to ``[-1, 1]`` first
                (presumably because ``img_as_ubyte`` rejects float values
                outside that range -- confirm against scikit-image).
            pad: If true, apply ``self.pad``; ``compute_padding`` must have
                been called beforehand.

        Returns:
            ``(temp, size)`` where ``temp`` is the still-open
            ``NamedTemporaryFile`` (created with ``delete=False``, so the
            caller must close and unlink it) and ``size`` is the PIL image
            size, i.e. ``(width, height)``.
        """
        if isinstance(img, torch.Tensor) and img.dtype == torch.float:
            img = torch.clamp(img, -1, 1)
        if pad:
            img = self.pad(img)
        img = tensor_to_image(img)
        pil_img = Image.fromarray(img_as_ubyte(img), mode="RGB")
        temp = tempfile.NamedTemporaryFile("w+b", suffix=".png", delete=False)
        try:
            pil_img.save(temp.name, format="png")
        except BaseException:
            # The caller never receives the handle on failure, so clean up
            # here before re-raising to avoid leaving the file behind.
            temp.close()
            Path(temp.name).unlink(missing_ok=True)
            raise
        return temp, pil_img.size

    def _forward(self, img0, img1, pad=False):
        """Match two images with RoMa.

        Args:
            img0: First image tensor (3-dimensional when ``pad`` is true).
            img1: Second image tensor (3-dimensional when ``pad`` is true).
            pad: If true, both images are padded to a common square size
                before matching. Returned coordinates then refer to the
                padded images. NOTE(review): kornia documents ``PadTo`` as
                padding on the right and bottom, which would keep the origin
                aligned with the unpadded image -- confirm for the installed
                version.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``: the matched pixel
            coordinates in each image followed by four ``None`` placeholders.
        """
        if pad:
            self.compute_padding(img0, img1)

        temp_files = []
        try:
            # Forward ``pad`` so the padder built above is actually applied.
            img0_temp, (w0, h0) = self.preprocess(img0, pad=pad)
            temp_files.append(img0_temp)
            img1_temp, (w1, h1) = self.preprocess(img1, pad=pad)
            temp_files.append(img1_temp)

            warp, certainty = self.roma_model.match(
                img0_temp.name, img1_temp.name, batched=False, device=self.device
            )
        finally:
            # Always remove the temporary PNGs, even when matching fails.
            for temp in temp_files:
                temp.close()
                Path(temp.name).unlink(missing_ok=True)

        matches, certainty = self.roma_model.sample(warp, certainty, num=self.max_keypoints)
        mkpts0, mkpts1 = self.roma_model.to_pixel_coordinates(matches, h0, w0, h1, w1)

        return mkpts0, mkpts1, None, None, None, None
|
|
|
|
|
class TinyRomaMatcher(BaseMatcher):
    """Matcher that wraps the ``tiny_roma_v1_outdoor`` model from ``romatch``.

    Unlike a file-based pipeline, images are normalized in memory and passed
    to the model directly as tensors.
    """

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        """Build the Tiny RoMa model and store the matching configuration.

        Args:
            device: Device passed to ``BaseMatcher`` and to the model factory.
            max_num_keypoints: Number of matches requested when sampling.
            *args: Accepted but not forwarded.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        self.roma_model = tiny_roma_v1_outdoor(device=device)
        self.max_keypoints = max_num_keypoints
        channel_mean = [0.485, 0.456, 0.406]
        channel_std = [0.229, 0.224, 0.225]
        self.normalize = tfm.Normalize(mean=channel_mean, std=channel_std)
        self.roma_model.train(False)

    def preprocess(self, img):
        """Normalize ``img`` and add a leading dimension of size one."""
        normalized = self.normalize(img)
        return normalized.unsqueeze(0)

    def _forward(self, img0, img1):
        """Match two image tensors with Tiny RoMa.

        Returns:
            ``(mkpts0, mkpts1, None, None, None, None)``: the matched pixel
            coordinates in each image followed by four ``None`` placeholders.
        """
        batch0, batch1 = (self.preprocess(image) for image in (img0, img1))

        # Height and width are the last two dimensions of each batch.
        height0, width0 = batch0.shape[-2], batch0.shape[-1]
        height1, width1 = batch1.shape[-2], batch1.shape[-1]

        model = self.roma_model
        dense_warp, dense_certainty = model.match(batch0, batch1, batched=False)

        sampled_matches, _ = model.sample(
            dense_warp, dense_certainty, num=self.max_keypoints
        )
        mkpts0, mkpts1 = model.to_pixel_coordinates(
            sampled_matches, height0, width0, height1, width1
        )

        return (mkpts0, mkpts1) + (None,) * 4
|
|