File size: 3,352 Bytes
0a82b18
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
import torch
import torchvision.transforms as tfm
from kornia.augmentation import PadTo
from kornia.utils import tensor_to_image
import tempfile
from pathlib import Path


from matching import BaseMatcher, THIRD_PARTY_DIR
from matching.utils import add_to_path

add_to_path(THIRD_PARTY_DIR.joinpath("RoMa"))
from romatch import roma_outdoor, tiny_roma_v1_outdoor

from PIL import Image
from skimage.util import img_as_ubyte


class RomaMatcher(BaseMatcher):
    """Matcher that wraps the RoMa ``roma_outdoor`` model.

    Each input image tensor is written to a temporary PNG file and the two
    file paths are handed to ``roma_model.match``. The resulting warp is
    sampled down to at most ``max_num_keypoints`` matches, which are then
    converted to pixel coordinates.
    """

    # Not referenced by any method of this class in this file; kept because
    # they are part of the class's public attributes.
    dino_patch_size = 14
    coarse_ratio = 560 / 864

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        """Build the RoMa model and store the sampling budget.

        Args:
            device: Device forwarded to ``BaseMatcher`` and ``roma_outdoor``.
            max_num_keypoints: Number of matches requested from
                ``roma_model.sample``.
            *args: Accepted for signature compatibility; not used.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        self.roma_model = roma_outdoor(device=device)
        self.max_keypoints = max_num_keypoints
        self.normalize = tfm.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
        # Switch the model out of training mode.
        self.roma_model.train(False)

    def compute_padding(self, img0, img1):
        """Create ``self.pad``, a transform padding to a common square size.

        The square's side is the largest height or width found in either
        image. Both images are expected to be 3-dimensional (the shapes are
        unpacked as ``_, h, w``).
        """
        _, h0, w0 = img0.shape
        _, h1, w1 = img1.shape
        pad_dim = max(h0, w0, h1, w1)

        self.pad = PadTo((pad_dim, pad_dim), keepdim=True)

    @staticmethod
    def _discard_temp_file(temp):
        """Close a temporary file handle and delete its file from disk."""
        temp.close()
        try:
            Path(temp.name).unlink()
        except FileNotFoundError:
            # Already removed; nothing left to clean up.
            pass

    def preprocess(self, img: torch.Tensor, pad=False) -> tuple:
        """Write ``img`` to a temporary PNG file.

        Args:
            img: Image tensor. Float tensors are clamped to ``[-1, 1]``
                before conversion to 8-bit.
            pad: When true, apply ``self.pad`` first. ``compute_padding``
                must have been called beforehand so that ``self.pad`` exists.

        Returns:
            ``(temp, size)`` where ``temp`` is the still-open
            ``NamedTemporaryFile`` (created with ``delete=False``, so the
            caller is responsible for closing and deleting it) and ``size``
            is the PIL ``(width, height)`` of the saved image.
        """
        if isinstance(img, torch.Tensor) and img.dtype == (torch.float):
            img = torch.clamp(img, -1, 1)
        if pad:
            img = self.pad(img)
        img = tensor_to_image(img)
        pil_img = Image.fromarray(img_as_ubyte(img), mode="RGB")
        temp = tempfile.NamedTemporaryFile("w+b", suffix=".png", delete=False)
        try:
            pil_img.save(temp.name, format="png")
        except Exception:
            # Do not leave an orphaned file behind if encoding fails.
            self._discard_temp_file(temp)
            raise
        return temp, pil_img.size

    def _forward(self, img0, img1, pad=False):
        """Match two images and return the matched pixel coordinates.

        Args:
            img0: First image tensor.
            img1: Second image tensor.
            pad: When true, both images are padded to a common square size
                before matching.

        Returns:
            A 6-tuple ``(mkpts0, mkpts1, None, None, None, None)``; only the
            matched point coordinates are populated.
        """
        if pad:
            self.compute_padding(img0, img1)

        temp_files = []
        try:
            # Forward ``pad`` so the transform built above is actually applied.
            img0_temp, img0_size = self.preprocess(img0, pad=pad)
            temp_files.append(img0_temp)
            img1_temp, img1_size = self.preprocess(img1, pad=pad)
            temp_files.append(img1_temp)

            warp, certainty = self.roma_model.match(
                img0_temp.name, img1_temp.name, batched=False, device=self.device
            )
        finally:
            # Temp files are created with delete=False, so remove them even
            # when preprocessing or matching raises.
            for temp in temp_files:
                self._discard_temp_file(temp)

        # PIL reports size as (width, height).
        w0, h0 = img0_size
        w1, h1 = img1_size

        matches, certainty = self.roma_model.sample(warp, certainty, num=self.max_keypoints)
        # NOTE(review): with pad=True the sizes above are the padded sizes, so
        # coordinates are in the padded frame. kornia documents PadTo as
        # padding on the right/bottom, which would leave coordinates inside
        # the original extent unchanged - confirm. Matches that fall in the
        # padded border are not filtered here.
        mkpts0, mkpts1 = self.roma_model.to_pixel_coordinates(matches, h0, w0, h1, w1)

        return mkpts0, mkpts1, None, None, None, None


class TinyRomaMatcher(BaseMatcher):
    """Matcher that wraps the RoMa ``tiny_roma_v1_outdoor`` model.

    Images are normalized and given a leading batch dimension, then passed
    to the model as tensors.
    """

    def __init__(self, device="cpu", max_num_keypoints=2048, *args, **kwargs):
        """Build the tiny RoMa model and store the sampling budget.

        Args:
            device: Device forwarded to ``BaseMatcher`` and
                ``tiny_roma_v1_outdoor``.
            max_num_keypoints: Number of matches requested from
                ``roma_model.sample``.
            *args: Accepted for signature compatibility; not used.
            **kwargs: Forwarded to ``BaseMatcher.__init__``.
        """
        super().__init__(device, **kwargs)
        self.max_keypoints = max_num_keypoints
        self.normalize = tfm.Normalize(
            mean=[0.485, 0.456, 0.406],
            std=[0.229, 0.224, 0.225],
        )
        self.roma_model = tiny_roma_v1_outdoor(device=device)
        # Switch the model out of training mode.
        self.roma_model.train(False)

    def preprocess(self, img):
        """Normalize ``img`` and prepend a batch dimension of size one."""
        normalized = self.normalize(img)
        return normalized.unsqueeze(0)

    def _forward(self, img0, img1):
        """Match two images and return the matched pixel coordinates.

        Returns:
            A 6-tuple ``(mkpts0, mkpts1, None, None, None, None)``; only the
            matched point coordinates are populated.
        """
        batch0 = self.preprocess(img0)
        batch1 = self.preprocess(img1)

        # The last two dimensions are taken as (height, width).
        height0, width0 = batch0.shape[-2], batch0.shape[-1]
        height1, width1 = batch1.shape[-2], batch1.shape[-1]

        # The tensors are passed through as-is; no explicit device transfer
        # is performed in this method.
        warp, certainty = self.roma_model.match(batch0, batch1, batched=False)

        sampled = self.roma_model.sample(warp, certainty, num=self.max_keypoints)
        matches = sampled[0]
        mkpts0, mkpts1 = self.roma_model.to_pixel_coordinates(
            matches, height0, width0, height1, width1
        )

        return mkpts0, mkpts1, None, None, None, None