Pawel Piwowarski
init commit
0a82b18
import torch
import torchvision.transforms as tfm
from kornia.augmentation import PadTo
from kornia.utils import tensor_to_image
import tempfile
from pathlib import Path
from matching import BaseMatcher, THIRD_PARTY_DIR
from matching.utils import add_to_path
add_to_path(THIRD_PARTY_DIR.joinpath("RoMa"))
from romatch import roma_outdoor, tiny_roma_v1_outdoor
from PIL import Image
from skimage.util import img_as_ubyte
class RomaMatcher(BaseMatcher):
dino_patch_size = 14
coarse_ratio = 560 / 864
def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
super().__init__(device, **kwargs)
self.roma_model = roma_outdoor(device=device)
self.max_keypoints = max_num_keypoints
self.normalize = tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.roma_model.train(False)
def compute_padding(self, img0, img1):
_, h0, w0 = img0.shape
_, h1, w1 = img1.shape
pad_dim = max(h0, w0, h1, w1)
self.pad = PadTo((pad_dim, pad_dim), keepdim=True)
def preprocess(self, img: torch.Tensor, pad=False) -> Image:
if isinstance(img, torch.Tensor) and img.dtype == (torch.float):
img = torch.clamp(img, -1, 1)
if pad:
img = self.pad(img)
img = tensor_to_image(img)
pil_img = Image.fromarray(img_as_ubyte(img), mode="RGB")
temp = tempfile.NamedTemporaryFile("w+b", suffix=".png", delete=False)
pil_img.save(temp.name, format="png")
return temp, pil_img.size
def _forward(self, img0, img1, pad=False):
if pad:
self.compute_padding(img0, img1)
img0_temp, img0_size = self.preprocess(img0)
img1_temp, img1_size = self.preprocess(img1)
w0, h0 = img0_size
w1, h1 = img1_size
warp, certainty = self.roma_model.match(img0_temp.name, img1_temp.name, batched=False, device=self.device)
img0_temp.close(), img1_temp.close()
Path(img0_temp.name).unlink()
Path(img1_temp.name).unlink()
matches, certainty = self.roma_model.sample(warp, certainty, num=self.max_keypoints)
mkpts0, mkpts1 = self.roma_model.to_pixel_coordinates(matches, h0, w0, h1, w1)
return mkpts0, mkpts1, None, None, None, None
class TinyRomaMatcher(BaseMatcher):
def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
super().__init__(device, **kwargs)
self.roma_model = tiny_roma_v1_outdoor(device=device)
self.max_keypoints = max_num_keypoints
self.normalize = tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
self.roma_model.train(False)
def preprocess(self, img):
return self.normalize(img).unsqueeze(0)
def _forward(self, img0, img1):
img0 = self.preprocess(img0)
img1 = self.preprocess(img1)
h0, w0 = img0.shape[-2:]
h1, w1 = img1.shape[-2:]
# batch = {"im_A": img0.to(self.device), "im_B": img1.to(self.device)}
warp, certainty = self.roma_model.match(img0, img1, batched=False)
matches, certainty = self.roma_model.sample(warp, certainty, num=self.max_keypoints)
mkpts0, mkpts1 = self.roma_model.to_pixel_coordinates(matches, h0, w0, h1, w1)
return mkpts0, mkpts1, None, None, None, None