|
from typing import Dict, List, Tuple, Union

import torch

from matanyone.inference.object_info import ObjectInfo
|
|
|
|
|
class ObjectManager:
    """
    Track the set of live objects and map between their two kinds of IDs.

    Object IDs are immutable. The same ID always represents the same object.

    Temporary IDs are the positions of each object in the tensor, i.e. the
    order in which per-object tensors are stacked by ``realize_dict`` and
    ``make_one_hot``. They change as objects get removed, so that the live
    objects always occupy the contiguous range ``1..num_obj``.

    Temporary IDs start from 1.
    """

    def __init__(self):
        # ObjectInfo -> temporary id. Insertion order equals temporary-id order,
        # because new ids are appended as count + 1 and deletions rebuild in order.
        self.obj_to_tmp_id: Dict[ObjectInfo, int] = {}
        # temporary id -> ObjectInfo; the inverse of obj_to_tmp_id
        self.tmp_id_to_obj: Dict[int, ObjectInfo] = {}
        # object id -> ObjectInfo; rebuilt from obj_to_tmp_id after every change
        self.obj_id_to_obj: Dict[int, ObjectInfo] = {}

        # every object id ever added, in order of first addition; never pruned
        self.all_historical_object_ids: List[int] = []

    def _recompute_obj_id_to_obj_mapping(self) -> None:
        """Rebuild ``obj_id_to_obj`` from the current keys of ``obj_to_tmp_id``."""
        self.obj_id_to_obj = {obj.id: obj for obj in self.obj_to_tmp_id}

    def add_new_objects(
        self,
        objects: Union[List[ObjectInfo], ObjectInfo, List[int], int],
    ) -> Tuple[List[int], List[int]]:
        """
        Register objects, assigning a fresh temporary ID to each untracked one.

        Args:
            objects: a single ObjectInfo or object ID, or a list of them.
                Plain integers are wrapped as ``ObjectInfo(id=...)``.

        Returns:
            ``(tmp_ids, obj_ids)``: the temporary and object IDs corresponding
            to ``objects``, in input order. Objects that are already tracked
            keep their current temporary ID.

        Raises:
            AssertionError: if the resulting temporary IDs are not ascending,
                e.g. when already-tracked objects are passed out of order.
        """
        if not isinstance(objects, list):
            objects = [objects]

        corresponding_tmp_ids = []
        corresponding_obj_ids = []
        for obj in objects:
            if isinstance(obj, int):
                obj = ObjectInfo(id=obj)

            if obj in self.obj_to_tmp_id:
                # already tracked: reuse its current temporary id
                corresponding_tmp_ids.append(self.obj_to_tmp_id[obj])
                corresponding_obj_ids.append(obj.id)
                continue

            # store a new ObjectInfo built from the id only, not the caller's instance
            new_obj = ObjectInfo(id=obj.id)
            # temporary ids are contiguous, so the next free one is count + 1
            new_tmp_id = len(self.obj_to_tmp_id) + 1
            self.obj_to_tmp_id[new_obj] = new_tmp_id
            self.tmp_id_to_obj[new_tmp_id] = new_obj
            self.all_historical_object_ids.append(new_obj.id)
            corresponding_tmp_ids.append(new_tmp_id)
            corresponding_obj_ids.append(new_obj.id)

        self._recompute_obj_id_to_obj_mapping()
        # internal invariant: the returned temporary ids must be ascending
        assert corresponding_tmp_ids == sorted(corresponding_tmp_ids)
        return corresponding_tmp_ids, corresponding_obj_ids

    def delete_objects(self, obj_ids_to_remove: Union[int, List[int]]) -> None:
        """
        Stop tracking the given objects and re-pack the temporary IDs.

        Surviving objects keep their relative order and are renumbered to the
        contiguous range ``1..num_obj``. IDs that are not currently tracked
        are ignored. ``all_historical_object_ids`` is not modified.

        Args:
            obj_ids_to_remove: a single object ID or a list of object IDs.
        """
        if isinstance(obj_ids_to_remove, int):
            obj_ids_to_remove = [obj_ids_to_remove]

        # walk objects in ascending temporary-id order so survivors keep their order
        survivors = [
            self.tmp_id_to_obj[tmp_id]
            for tmp_id in sorted(self.tmp_id_to_obj)
            if self.tmp_id_to_obj[tmp_id].id not in obj_ids_to_remove
        ]

        # renumber survivors 1, 2, ... and rebuild both directions of the mapping
        self.tmp_id_to_obj = dict(enumerate(survivors, start=1))
        self.obj_to_tmp_id = {obj: tmp_id for tmp_id, obj in self.tmp_id_to_obj.items()}
        self._recompute_obj_id_to_obj_mapping()

    def purge_inactive_objects(
        self, max_missed_detection_count: int
    ) -> Tuple[bool, List[int], List[int]]:
        """
        Delete every object whose ``poke_count`` exceeds a threshold.

        Args:
            max_missed_detection_count: objects with ``poke_count`` strictly
                greater than this value are deleted.

        Returns:
            ``(purge_activated, tmp_id_to_keep, obj_id_to_keep)``:
            ``purge_activated`` is True if at least one object was deleted.
            ``tmp_id_to_keep`` holds the temporary IDs of the survivors as
            they were *before* the purge. ``obj_id_to_keep`` holds their
            object IDs.
        """
        obj_id_to_be_deleted = []
        tmp_id_to_keep = []
        obj_id_to_keep = []
        for obj, tmp_id in self.obj_to_tmp_id.items():
            if obj.poke_count > max_missed_detection_count:
                obj_id_to_be_deleted.append(obj.id)
            else:
                tmp_id_to_keep.append(tmp_id)
                obj_id_to_keep.append(obj.id)

        purge_activated = len(obj_id_to_be_deleted) > 0
        if purge_activated:
            self.delete_objects(obj_id_to_be_deleted)
        return purge_activated, tmp_id_to_keep, obj_id_to_keep

    def tmp_to_obj_cls(self, mask) -> torch.Tensor:
        """
        Convert a label mask from temporary IDs to object IDs.

        Elements equal to a tracked temporary ID are set to that object's ID.
        Every other element becomes 0. The result has the same shape, dtype
        and device as ``mask``.
        """
        new_mask = torch.zeros_like(mask)
        for tmp_id, obj in self.tmp_id_to_obj.items():
            # test against the input mask, not new_mask, so a written object id
            # is never re-matched as a temporary id in a later iteration
            new_mask[mask == tmp_id] = obj.id
        return new_mask

    def get_tmp_to_obj_mapping(self) -> Dict[int, int]:
        """
        Return the current ``{temporary_id: object_id}`` mapping as plain ints.

        The result is a new dict, so mutating it does not affect the manager.
        """
        return {tmp_id: obj.id for tmp_id, obj in self.tmp_id_to_obj.items()}

    def realize_dict(self, obj_dict, dim=1) -> torch.Tensor:
        """
        Stack per-object tensors in temporary-ID order.

        Args:
            obj_dict: mapping from object ID to tensor. It must contain every
                tracked object ID, and its values must share a shape, as
                required by ``torch.stack``.
            dim: dimension along which the per-object tensors are stacked.

        Returns:
            The stacked tensor. Index ``i`` along ``dim`` belongs to the
            object with temporary ID ``i + 1``.

        Raises:
            NotImplementedError: if a tracked object ID is missing from
                ``obj_dict``.
        """
        output = []
        # tmp_id_to_obj iterates in ascending temporary-id order
        for obj in self.tmp_id_to_obj.values():
            if obj.id not in obj_dict:
                raise NotImplementedError(
                    f'realize_dict: no entry for tracked object id {obj.id}')
            output.append(obj_dict[obj.id])
        return torch.stack(output, dim=dim)

    def make_one_hot(self, cls_mask) -> torch.Tensor:
        """
        Split an object-ID label mask into one boolean mask per tracked object.

        Returns:
            A tensor of shape ``(num_obj, *cls_mask.shape)``. Slice ``i`` is
            ``cls_mask == id`` for the object with temporary ID ``i + 1``.
            With no tracked objects, an empty bool tensor of shape
            ``(0, *cls_mask.shape)`` on ``cls_mask.device`` is returned.
        """
        masks = [cls_mask == obj.id for obj in self.tmp_id_to_obj.values()]
        if not masks:
            return torch.zeros((0, *cls_mask.shape), dtype=torch.bool, device=cls_mask.device)
        return torch.stack(masks, dim=0)

    @property
    def all_obj_ids(self) -> List[int]:
        """Object IDs of all tracked objects, in temporary-ID order."""
        return [k.id for k in self.obj_to_tmp_id]

    @property
    def num_obj(self) -> int:
        """Number of currently tracked objects."""
        return len(self.obj_to_tmp_id)

    def has_all(self, objects: List[int]) -> bool:
        """Return True if every entry of ``objects`` is currently tracked (True for empty input)."""
        return all(obj in self.obj_to_tmp_id for obj in objects)

    def find_object_by_id(self, obj_id) -> ObjectInfo:
        """
        Return the tracked ObjectInfo with the given object ID.

        Raises:
            KeyError: if ``obj_id`` is not tracked.
        """
        return self.obj_id_to_obj[obj_id]

    def find_tmp_by_id(self, obj_id) -> int:
        """
        Return the current temporary ID of the object with the given object ID.

        Raises:
            KeyError: if ``obj_id`` is not tracked.
        """
        return self.obj_to_tmp_id[self.obj_id_to_obj[obj_id]]
|
|