# -*- coding: UTF-8 -*-
'''=================================================
@Project -> File   pram -> multimap3d
@IDE     PyCharm
@Author  fx221@cam.ac.uk
@Date    04/03/2024 13:47
=================================================='''
import numpy as np
import os
import os.path as osp
import time
import cv2
import torch
import yaml
from copy import deepcopy

from recognition.vis_seg import vis_seg_point, generate_color_dic, vis_inlier, plot_matches
from localization.base_model import dynamic_load
import localization.matchers as matchers
from localization.match_features_batch import confs as matcher_confs
from nets.gm import GM
from tools.common import resize_img
from localization.singlemap3d import SingleMap3D
from localization.frame import Frame


class MultiMap3D:
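    """Collection of SingleMap3D sub-maps covering multiple scenes.

    Predicted segment ids of a query Frame are mapped to the sub-map they belong to,
    which is then used for 2D-3D matching, pose estimation and optional pose refinement.
    """
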
    def __init__(self, config, viewer=None, save_dir=None):
        self.config = config
        self.save_dir = save_dir
        self.scenes = []
        self.sid_scene_name = []
        self.sub_maps = {}
        self.scene_name_start_sid = {}
        self.loc_config = config['localization']
        if self.save_dir is not None:
            os.makedirs(self.save_dir, exist_ok=True)

        self.matching_method = config['localization']['matching_method']
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        Model = dynamic_load(matchers, self.matching_method)
        self.matcher = Model(matcher_confs[self.matching_method]['model']).eval().to(device)

        self.initialize_map(config=config)
        self.viewer = viewer

        # options
        self.do_refinement = self.loc_config['do_refinement']
        self.refinement_method = self.loc_config['refinement_method']
        self.semantic_matching = self.loc_config['semantic_matching']
        self.do_pre_filtering = self.loc_config['pre_filtering_th'] > 0
        self.pre_filtering_th = self.loc_config['pre_filtering_th']

    def initialize_map(self, config):
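        """Build one SingleMap3D per scene listed in the per-dataset YAML configs and
        record the global segment-id offset of each scene in self.scene_name_start_sid."""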
        n_class = 0
        datasets = config['dataset']
        for name in datasets:
            config_path = osp.join(config['config_path'], '{:s}.yaml'.format(name))
            dataset_name = name
            with open(config_path, 'r') as f:
                scene_config = yaml.load(f, Loader=yaml.Loader)

            scenes = scene_config['scenes']
            for sid, scene in enumerate(scenes):
                self.scenes.append(name + '/' + scene)
                new_config = deepcopy(config)
                new_config['dataset_path'] = osp.join(config['dataset_path'], dataset_name, scene)
                new_config['landmark_path'] = osp.join(config['landmark_path'], dataset_name, scene)
                new_config['n_cluster'] = scene_config[scene]['n_cluster']
                new_config['cluster_mode'] = scene_config[scene]['cluster_mode']
                new_config['cluster_method'] = scene_config[scene]['cluster_method']
                new_config['gt_pose_path'] = scene_config[scene]['gt_pose_path']
                new_config['image_path_prefix'] = scene_config[scene]['image_path_prefix']

                sub_map = SingleMap3D(config=new_config,
                                      matcher=self.matcher,
                                      with_compress=config['localization']['with_compress'],
                                      start_sid=n_class)
                self.sub_maps[dataset_name + '/' + scene] = sub_map

                n_scene_class = scene_config[scene]['n_cluster']
                self.sid_scene_name = self.sid_scene_name + [dataset_name + '/' + scene for ni in range(n_scene_class)]
                self.scene_name_start_sid[dataset_name + '/' + scene] = n_class
                n_class = n_class + n_scene_class
                # break

        print('Loaded {} sub_maps from {} datasets'.format(len(self.sub_maps), len(datasets)))

    def run(self, q_frame: Frame):
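        """Localize a query frame against the sub-maps.

        The top-k predicted segments are tried in order: each segment is mapped to its
        sub-map and matched against a reference frame, the estimated pose is verified,
        and the loop stops at the first accepted pose, which is optionally refined.
        Returns True on success, False otherwise.
        """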
        show = self.loc_config['show']
        seg_color = generate_color_dic(n_seg=2000)
        if show:
            cv2.namedWindow('loc', cv2.WINDOW_NORMAL)

        q_loc_segs = self.process_segmentations(segs=torch.from_numpy(q_frame.segmentations),
                                                topk=self.loc_config['seg_k'])
        q_pred_segs_top1 = q_frame.seg_ids  # initial results
        q_scene_name = q_frame.scene_name
        q_name = q_frame.name
        q_full_name = osp.join(q_scene_name, q_name)

        q_loc_sids = {}
        for v in q_loc_segs:
            q_loc_sids[v[0]] = (v[1], v[2])
        query_sids = list(q_loc_sids.keys())

        for i, sid in enumerate(query_sids):
            t_start = time.time()
            q_kpt_ids = q_loc_sids[sid][0]
            print(q_scene_name, q_name, sid)

            sid = sid - 1  # predicted segment ids start from 1 (0 is background), so shift to 0-based
            pred_scene_name = self.sid_scene_name[sid]
            start_seg_id = self.scene_name_start_sid[pred_scene_name]
            pred_sid_in_sub_scene = sid - self.scene_name_start_sid[pred_scene_name]
            pred_sub_map = self.sub_maps[pred_scene_name]
            pred_image_path_prefix = pred_sub_map.image_path_prefix
            print('pred/gt scene: {:s}, {:s}, sid: {:d}'.format(pred_scene_name, q_scene_name, pred_sid_in_sub_scene))
            print('{:s}/{:s}, pred: {:s}, sid: {:d}, order: {:d}'.format(q_scene_name, q_name, pred_scene_name, sid,
                                                                         i))

            if (q_kpt_ids.shape[0] >= self.loc_config['min_kpts']
                    and self.semantic_matching
                    and pred_sub_map.check_semantic_consistency(q_frame=q_frame,
                                                                sid=pred_sid_in_sub_scene,
                                                                overlap_ratio=0.5)):
                semantic_matching = True
            else:
                q_kpt_ids = np.arange(q_frame.keypoints.shape[0])
                semantic_matching = False

            print_text = f'Semantic matching - {semantic_matching}! Query kpts {q_kpt_ids.shape[0]} for {i}th seg {sid}'
            print(print_text)

            ret = pred_sub_map.localize_with_ref_frame(q_frame=q_frame,
                                                       q_kpt_ids=q_kpt_ids,
                                                       sid=pred_sid_in_sub_scene,
                                                       semantic_matching=semantic_matching)
            q_frame.time_loc = q_frame.time_loc + time.time() - t_start  # accumulate localization time
            if show:
                reference_frame = pred_sub_map.reference_frames[ret['reference_frame_id']]
                ref_img = cv2.imread(osp.join(self.config['dataset_path'], pred_scene_name, pred_image_path_prefix,
                                              reference_frame.name))
                q_img_seg = vis_seg_point(img=q_frame.image, kpts=q_frame.keypoints[q_kpt_ids, :2],
                                          segs=q_frame.seg_ids[q_kpt_ids] + 1,
                                          seg_color=seg_color)
                matched_points3D_ids = ret['matched_point3D_ids']
                ref_sids = np.array([pred_sub_map.point3Ds[v].seg_id for v in matched_points3D_ids]) + \
                           self.scene_name_start_sid[pred_scene_name] + 1  # start from 1 as bg is 0
                ref_img_seg = vis_seg_point(img=ref_img, kpts=ret['matched_ref_keypoints'], segs=ref_sids,
                                            seg_color=seg_color)

                q_matched_kpts = ret['matched_keypoints']
                ref_matched_kpts = ret['matched_ref_keypoints']
                img_loc_matching = plot_matches(img1=q_img_seg, img2=ref_img_seg,
                                                pts1=q_matched_kpts, pts2=ref_matched_kpts,
                                                inliers=np.array([True for i in range(q_matched_kpts.shape[0])]),
                                                radius=9, line_thickness=3)

                q_frame.image_matching_tmp = img_loc_matching
                q_frame.reference_frame_name_tmp = osp.join(self.config['dataset_path'],
                                                            pred_scene_name,
                                                            pred_image_path_prefix,
                                                            reference_frame.name)
                # ret['image_matching'] = img_loc_matching
                # ret['reference_frame_name'] = osp.join(self.config['dataset_path'],
                #                                        pred_scene_name,
                #                                        pred_image_path_prefix,
                #                                        reference_frame.name)
                q_ref_img_matching = np.hstack([resize_img(q_img_seg, nh=512),
                                                resize_img(ref_img_seg, nh=512),
                                                resize_img(img_loc_matching, nh=512)])

            ret['order'] = i
            ret['matched_scene_name'] = pred_scene_name

            if not ret['success']:
                num_matches = ret['matched_keypoints'].shape[0]
                num_inliers = ret['num_inliers']
                print_text = f'Localization failed with {num_matches}/{q_kpt_ids.shape[0]} matches and {num_inliers} inliers, order {i}'
                print(print_text)
                if show:
                    show_text = 'FAIL! order: {:d}/{:d}-{:d}/{:d}'.format(i, len(q_loc_segs),
                                                                          num_matches,
                                                                          q_kpt_ids.shape[0])
                    q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
                                              radius=9 + 2, thickness=2)
                    q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
                                               fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
                                               thickness=2, lineType=cv2.LINE_AA)
                    q_frame.image_inlier_tmp = q_img_inlier
                    q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
                    cv2.imshow('loc', q_img_loc)
                    key = cv2.waitKey(self.loc_config['show_time'])
                    if key == ord('q'):
                        cv2.destroyAllWindows()
                        exit(0)
                continue
            if show:
                q_err, t_err = q_frame.compute_pose_error()
                num_matches = ret['matched_keypoints'].shape[0]
                num_inliers = ret['num_inliers']
                show_text = 'order: {:d}/{:d}, k/m/i: {:d}/{:d}/{:d}'.format(
                    i, len(q_loc_segs), q_kpt_ids.shape[0], num_matches, num_inliers)
                q_img_inlier = vis_inlier(img=q_img_seg, kpts=ret['matched_keypoints'], inliers=ret['inliers'],
                                          radius=9 + 2, thickness=2)
                q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 30),
                                           fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
                                           thickness=2, lineType=cv2.LINE_AA)
                show_text = 'r_err:{:.2f}, t_err:{:.2f}'.format(q_err, t_err)
                q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 80),
                                           fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
                                           thickness=2, lineType=cv2.LINE_AA)
                q_frame.image_inlier_tmp = q_img_inlier
                q_img_loc = np.hstack([resize_img(q_ref_img_matching, nh=512), resize_img(q_img_inlier, nh=512)])
                cv2.imshow('loc', q_img_loc)
                key = cv2.waitKey(self.loc_config['show_time'])
                if key == ord('q'):
                    cv2.destroyAllWindows()
                    exit(0)

            success = self.verify_and_update(q_frame=q_frame, ret=ret)
            if not success:
                continue
            else:
                break

        if q_frame.tracking_status is None:
            print('Failed to find a proper reference frame.')
            return False

        # do refinement
        if not self.do_refinement:
            return True
        else:
            t_start = time.time()
            pred_sub_map = self.sub_maps[q_frame.matched_scene_name]
            if q_frame.tracking_status is True and np.sum(q_frame.matched_inliers) >= 64:
                ret = pred_sub_map.refine_pose(q_frame=q_frame, refinement_method=self.loc_config['refinement_method'])
            else:
                ret = pred_sub_map.refine_pose(q_frame=q_frame,
                                               refinement_method='matching')  # do not trust the pose for projection
            q_frame.time_ref = time.time() - t_start

            inlier_mask = np.array(ret['inliers'])
            q_frame.qvec = ret['qvec']
            q_frame.tvec = ret['tvec']
            q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask]
            q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask]
            q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask]
            q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask]
            q_frame.matched_sids = ret['matched_sids'][inlier_mask]
            q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask]
            q_frame.refinement_reference_frame_ids = ret['refinement_reference_frame_ids']
            q_frame.reference_frame_id = ret['reference_frame_id']

            q_err, t_err = q_frame.compute_pose_error()
            ref_full_name = q_frame.matched_scene_name + '/' + pred_sub_map.reference_frames[
                q_frame.reference_frame_id].name
            print_text = 'Localization of {:s} succeeded with inliers {:d}/{:d}, ref_name: {:s}, order: {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
                q_full_name, ret['num_inliers'], len(ret['inliers']), ref_full_name, q_frame.matched_order, q_err,
                t_err)
            print(print_text)

            if show:
                q_err, t_err = q_frame.compute_pose_error()
                num_matches = ret['matched_keypoints'].shape[0]
                num_inliers = ret['num_inliers']
                show_text = 'Ref:{:d}/{:d},r_err:{:.2f}/t_err:{:.2f}'.format(num_matches, num_inliers, q_err,
                                                                             t_err)
                q_img_inlier = cv2.putText(img=q_img_inlier, text=show_text, org=(30, 130),
                                           fontFace=cv2.FONT_HERSHEY_SIMPLEX, fontScale=1, color=(0, 0, 255),
                                           thickness=2, lineType=cv2.LINE_AA)
                q_frame.image_inlier = q_img_inlier

        return True

    def verify_and_update(self, q_frame: Frame, ret: dict):
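        """Check whether the pose in ret has enough inliers, keep the best result seen so
        far on q_frame, and set q_frame.tracking_status accordingly. Returns True if accepted."""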
        num_matches = ret['matched_keypoints'].shape[0]
        num_inliers = ret['num_inliers']
        if q_frame.matched_keypoints is None or np.sum(q_frame.matched_inliers) < num_inliers:
            self.update_query_frame(q_frame=q_frame, ret=ret)

        q_err, t_err = q_frame.compute_pose_error(pred_qvec=ret['qvec'], pred_tvec=ret['tvec'])
        if num_inliers < self.loc_config['min_inliers']:
            print_text = 'Failed due to insufficient inliers ({:d}), order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
                ret['num_inliers'], ret['order'], q_err, t_err)
            print(print_text)
            q_frame.tracking_status = False
            return False
        else:
            print_text = 'Succeeded! Found {}/{} 2D-3D inliers, order {:d}, q_err: {:.2f}, t_err: {:.2f}'.format(
                num_inliers, num_matches, ret['order'], q_err, t_err)
            print(print_text)
            q_frame.tracking_status = True
            return True

    def update_query_frame(self, q_frame, ret):
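        """Copy the matching and pose results in ret (and any cached visualization images)
        onto the query frame."""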
        q_frame.matched_scene_name = ret['matched_scene_name']
        q_frame.reference_frame_id = ret['reference_frame_id']
        q_frame.qvec = ret['qvec']
        q_frame.tvec = ret['tvec']

        inlier_mask = np.array(ret['inliers'])
        q_frame.matched_keypoints = ret['matched_keypoints']
        q_frame.matched_keypoint_ids = ret['matched_keypoint_ids']
        q_frame.matched_xyzs = ret['matched_xyzs']
        q_frame.matched_point3D_ids = ret['matched_point3D_ids']
        q_frame.matched_sids = ret['matched_sids']
        q_frame.matched_inliers = np.array(ret['inliers'])
        q_frame.matched_order = ret['order']

        if q_frame.image_inlier_tmp is not None:
            q_frame.image_inlier = deepcopy(q_frame.image_inlier_tmp)
        if q_frame.image_matching_tmp is not None:
            q_frame.image_matching = deepcopy(q_frame.image_matching_tmp)
        if q_frame.reference_frame_name_tmp is not None:
            q_frame.reference_frame_name = q_frame.reference_frame_name_tmp

        # inlier_mask = np.array(ret['inliers'])
        # q_frame.matched_keypoints = ret['matched_keypoints'][inlier_mask]
        # q_frame.matched_keypoint_ids = ret['matched_keypoint_ids'][inlier_mask]
        # q_frame.matched_xyzs = ret['matched_xyzs'][inlier_mask]
        # q_frame.matched_point3D_ids = ret['matched_point3D_ids'][inlier_mask]
        # q_frame.matched_sids = ret['matched_sids'][inlier_mask]
        # q_frame.matched_inliers = np.array(ret['inliers'])[inlier_mask]

        # print('update_query_frame: ', q_frame.matched_keypoint_ids.shape, q_frame.matched_keypoints.shape,
        #       q_frame.matched_xyzs.shape, q_frame.matched_xyzs.shape, np.sum(q_frame.matched_inliers))

    def process_segmentations(self, segs, topk=10):
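        """Group keypoints by their predicted segment id, taking each keypoint's ranked
        class predictions in order, skip the background id 0, rank segments at each rank
        level by the number of supporting keypoints, and return up to topk
        (sid, keypoint_ids, mean_score) tuples."""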
        pred_values, pred_ids = torch.topk(segs, k=segs.shape[-1], largest=True, dim=-1)  # [N, C]
        pred_values = pred_values.numpy()
        pred_ids = pred_ids.numpy()

        out = []
        used_sids = []
        for k in range(segs.shape[-1]):
            values_k = pred_values[:, k]
            ids_k = pred_ids[:, k]
            uids = np.unique(ids_k)

            out_k = []
            for sid in uids:
                if sid == 0:
                    continue
                if sid in used_sids:
                    continue
                used_sids.append(sid)
                ids = np.where(ids_k == sid)[0]
                score = np.mean(values_k[ids])
                # score = np.median(values_k[ids])
                # score = 100 - k
                # out_k.append((ids.shape[0], sid - 1, ids, score))
                out_k.append((ids.shape[0], sid, ids, score))

            out_k = sorted(out_k, key=lambda item: item[0], reverse=True)
            for v in out_k:
                out.append((v[1], v[2], v[3]))  # [sid, ids, score]
                if len(out) >= topk:
                    return out
        return out
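
# Illustrative usage sketch (not part of the original file; the YAML path below is a
# placeholder): the config is expected to provide at least the keys accessed above
# ('dataset', 'config_path', 'dataset_path', 'landmark_path', and a 'localization' section).
#
#   import yaml
#   with open('path/to/localization_config.yaml', 'r') as f:
#       config = yaml.load(f, Loader=yaml.Loader)
#   localizer = MultiMap3D(config=config, viewer=None, save_dir=None)
#   success = localizer.run(q_frame)  # q_frame: a localization.frame.Frame of the query image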