File size: 2,478 Bytes
78360e7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
# Copyright (c) Alibaba, Inc. and its affiliates.
import torch
import numpy as np
import argparse
from PIL import Image

def convert_to_numpy(image):
    """Convert a supported image container to a NumPy array.

    Args:
        image: a ``np.ndarray``, a ``torch.Tensor`` or a ``PIL.Image.Image``.

    Returns:
        np.ndarray: a copy of the input for ``np.ndarray``; ``np.array(image)``
        for a PIL image; ``image.detach().cpu().numpy()`` for a tensor (which
        may share memory with the tensor when it is already on the CPU — see
        ``torch.Tensor.numpy``).

    Raises:
        TypeError: if ``image`` is none of the supported types.
    """
    if isinstance(image, np.ndarray):
        return image.copy()
    if isinstance(image, torch.Tensor):
        return image.detach().cpu().numpy()
    if isinstance(image, Image.Image):
        return np.array(image)
    # Previously a bare f-string was raised, which Python 3 turns into an
    # unrelated "exceptions must derive from BaseException" TypeError and
    # hides this message; raise a real TypeError instead.
    raise TypeError(
        f'Unsupported datatype {type(image)}, only np.ndarray, torch.Tensor '
        f'and Pillow Image are supported.'
    )

class FlowAnnotator:
    """Optical-flow annotator backed by a RAFT model.

    Args:
        cfg: mapping whose ``'PRETRAINED_MODEL'`` entry is the path of the RAFT
            checkpoint; it is read with ``torch.load(..., weights_only=True)``.
        device: device to run on; when ``None``, CUDA is used if available,
            otherwise the CPU.
    """

    def __init__(self, cfg, device=None):
        # RAFT and its helpers are imported when an annotator is constructed,
        # not when this module is imported.
        from .raft.raft import RAFT
        from .raft.utils.utils import InputPadder
        from .raft.utils import flow_viz

        raft_args = argparse.Namespace(
            small=False,
            mixed_precision=False,
            alternate_corr=False,
        )
        checkpoint_path = cfg['PRETRAINED_MODEL']
        if device is None:
            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.device = device

        self.model = RAFT(raft_args)
        state = torch.load(checkpoint_path, map_location="cpu", weights_only=True)
        # Every "module." substring is removed from the parameter names —
        # presumably a prefix left by a checkpoint saved from a wrapped
        # (e.g. DataParallel) model; TODO confirm.
        self.model.load_state_dict(
            {name.replace('module.', ''): weight for name, weight in state.items()}
        )
        self.model = self.model.to(self.device).eval()

        # Stored on the instance so forward() can use them.
        self.InputPadder = InputPadder
        self.flow_viz = flow_viz

    def forward(self, frames):
        """Estimate flow for every pair of consecutive frames.

        Args:
            frames: iterable of images accepted by ``convert_to_numpy`` (marked
                RGB in the original code). Each is cast to uint8 and permuted
                with ``(2, 0, 1)``, so it must be 3-D — presumably H x W x C;
                TODO confirm.

        Returns:
            tuple[list, list]: ``(flows, flow_images)``, each holding one entry
            per consecutive frame pair (one fewer than the number of frames,
            empty for fewer than two). ``flows`` contains the model's flow for
            the pair, permuted to channel-last and converted to NumPy;
            ``flow_images`` contains ``flow_viz.flow_to_image`` of each flow
            (RGB per the original annotation).
        """
        batch = []
        for frame in frames:
            pixels = convert_to_numpy(frame).astype(np.uint8)
            # (H, W, C) -> (1, C, H, W) float tensor on the model's device.
            batch.append(torch.from_numpy(pixels).permute(2, 0, 1).float()[None].to(self.device))

        flows, flow_images = [], []
        with torch.no_grad():
            for src, dst in zip(batch[:-1], batch[1:]):
                padder = self.InputPadder(src.shape)
                src, dst = padder.pad(src, dst)
                # The model returns two outputs; only the second one is kept.
                # iters=20 is presumably RAFT's refinement iteration count.
                _, flow = self.model(src, dst, iters=20, test_mode=True)
                # NOTE(review): the flow is computed on the padded frames and is
                # not cropped back here, so it is presumably at the padded size
                # when H/W are not already aligned — confirm whether callers
                # expect the input resolution (RAFT's InputPadder usually
                # provides an unpad() helper).
                flow = flow[0].permute(1, 2, 0).cpu().numpy()
                vis = self.flow_viz.flow_to_image(flow)
                flows.append(flow)
                flow_images.append(vis)
        return flows, flow_images  # RGB


class FlowVisAnnotator(FlowAnnotator):
    """``FlowAnnotator`` variant that returns only the flow visualizations."""

    def forward(self, frames):
        """Return one flow visualization per input frame.

        The parent yields one visualization per consecutive frame pair. The
        first one is repeated at the front so the list length matches the
        number of frames. Fewer than two frames yield an empty list.
        """
        _, vis_frames = super().forward(frames)
        return [*vis_frames[:1], *vis_frames]