from pathlib import Path

import time
import traceback
from functools import partial

import numpy as np
import torch
import cv2
from threadpoolctl import threadpool_limits
import multiprocess as mp
from PIL import Image
import supervision as sv

from pgnd.utils import get_root

root: Path = get_root(__file__)

from camera.multi_realsense import MultiRealsense
from camera.single_realsense import SingleRealsense

from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from utils.pcd_utils import depth2fgpcd

def get_mask_raw(depth, intr, extr, bbox, depth_threshold=(0.0, 2.0)):
    """Mask out pixels outside the depth range and outside the workspace bbox.

    depth is in meters. extr is assumed to map world to camera, so its inverse
    lifts the camera-frame point cloud into the world frame. bbox is
    ((x_min, x_max), (y_min, y_max)) in world coordinates.
    """
    points = depth2fgpcd(depth, intr).reshape(-1, 3)
    mask = np.logical_and(depth > depth_threshold[0], depth < depth_threshold[1])

    # lift points to the world frame and keep those inside the (x, y) bbox
    points = (np.linalg.inv(extr) @ np.concatenate(
        [points, np.ones((points.shape[0], 1), dtype=np.float32)], axis=1).T).T[:, :3]
    mask_bbox = np.logical_and(
        np.logical_and(points[:, 0] > bbox[0][0], points[:, 0] < bbox[0][1]),
        np.logical_and(points[:, 1] > bbox[1][0], points[:, 1] < bbox[1][1]),
    )
    mask_bbox = mask_bbox.reshape(depth.shape[0], depth.shape[1])
    mask = np.logical_and(mask, mask_bbox)
    return mask

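# Example call (hypothetical values): with an identity extrinsic the bbox is
# expressed directly in camera coordinates, so
#     mask = get_mask_raw(depth_m, intr, np.eye(4), ((-0.5, 0.5), (-0.5, 0.5)))
# keeps pixels between 0 and 2 m deep that also fall inside a 1 m x 1 m region
# around the optical axis.
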
def segment_process_func(cameras_output, intrs, extrs, text_prompts, processor,
                         grounding_model, image_predictor, bbox, device,
                         show_annotation=True):
    """Detect, segment, and back-project the prompted objects for each camera."""
    if intrs is None or extrs is None:
        # guard before any use of the calibration below
        print("No camera intrinsics and extrinsics provided")
        return {
            "color": [],
            "depth": [],
            "pts": [],
        }

    colors_list = []
    depths_list = []
    pts_list = []
    for ck, cam_out in cameras_output.items():
        image = cam_out["color"].copy()
        depth = cam_out["depth"].copy() / 1000.0  # mm -> m
        image = Image.fromarray(image)

        # open-vocabulary detection with Grounding DINO
        inputs = processor(images=image, text=text_prompts, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = grounding_model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.325,
            text_threshold=0.3,
            target_sizes=[image.size[::-1]],
        )
        input_boxes = results[0]["boxes"].cpu().numpy()
        objects = results[0]["labels"]

        depth_mask = get_mask_raw(depth, intrs[ck], extrs[ck], bbox)

        multi_objs = False
        if len(objects) > 1:
            # drop empty labels, oversized boxes, and boxes with no valid depth
            objects_masked = []
            input_boxes_masked = []
            for i, obj in enumerate(objects):
                if obj == '':
                    continue
                box = input_boxes[i].astype(int)
                if (box[3] - box[1]) * (box[2] - box[0]) > 500 * 400:
                    continue
                depth_mask_box = depth_mask[box[1]:box[3], box[0]:box[2]]
                if depth_mask_box.sum() > 0:
                    objects_masked.append(obj)
                    input_boxes_masked.append(box)
            objects = objects_masked
            input_boxes = np.array(input_boxes_masked)
        if len(objects) == 0:
            print("No objects detected")
            return {
                "color": [],
                "depth": [],
                "pts": [],
            }
        elif len(objects) > 1:
            multi_objs = True

        # prompt SAM2 with the surviving boxes
        image_predictor.set_image(np.array(image.convert("RGB")))
        masks, scores, logits = image_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
        # a single box yields (1, H, W); multiple boxes yield (n, 1, H, W),
        # so squeeze down to (n, H, W)
        if masks.ndim == 4:
            assert multi_objs
            masks = masks.squeeze(1)
        masks = masks.astype(bool)

        ID_TO_OBJECTS = {i: obj for i, obj in enumerate(objects, start=1)}
        object_ids = np.arange(1, len(objects) + 1)

        if show_annotation:
            # draw boxes, labels, and masks on the color image
            detections = sv.Detections(
                xyxy=sv.mask_to_xyxy(masks),
                mask=masks,
                class_id=np.array(object_ids, dtype=np.int32),
            )
            box_annotator = sv.BoxAnnotator()
            annotated_frame = box_annotator.annotate(scene=np.array(image).astype(np.uint8), detections=detections)
            label_annotator = sv.LabelAnnotator()
            annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
            mask_annotator = sv.MaskAnnotator()
            annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
            colors_list.append(annotated_frame)
        else:
            colors_list.append(np.array(image))

        depths_list.append(cam_out["depth"].copy())

        # union of per-object masks, intersected with the depth/workspace mask
        masks = np.logical_or.reduce(masks, axis=0, keepdims=True)
        masks = np.logical_and(masks, depth_mask)
        masks = masks.reshape(-1)
        assert masks.shape[0] == depth.shape[0] * depth.shape[1]

        # back-project masked pixels and transform them into the world frame
        points = depth2fgpcd(depth, intrs[ck]).reshape(-1, 3)
        points = (np.linalg.inv(extrs[ck]) @ np.concatenate(
            [points, np.ones((points.shape[0], 1), dtype=np.float32)], axis=1).T).T[:, :3]
        points = points[masks]
        pts_list.append(points)

    return {
        "color": colors_list,
        "depth": depths_list,
        "pts": pts_list,
    }

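# Downstream consumers can fuse the per-camera clouds from the returned dict,
# e.g. np.concatenate(out["pts"], axis=0) gives a single (N, 3) world-frame
# point cloud (illustrative; `out` names a result of segment_process_func).
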
class SegmentPerception(mp.Process):

    def __init__(
        self,
        realsense: MultiRealsense | SingleRealsense,
        capture_fps,
        record_fps,
        record_time,
        exp_name=None,
        bbox=None,
        data_dir="data",
        text_prompts="white cotton rope.",
        show_annotation=True,
        device=None,
        verbose=False,
    ):
        super().__init__()
        self.verbose = verbose

        self.capture_fps = capture_fps
        self.record_fps = record_fps
        self.record_time = record_time
        self.exp_name = exp_name
        self.data_dir = data_dir
        self.bbox = bbox

        self.text_prompts = text_prompts
        self.show_annotation = show_annotation

        if self.exp_name is None:
            assert self.record_fps == 0

        self.realsense = realsense
        self.perception_q = mp.Queue(maxsize=1)

        self.num_cam = len(realsense.cameras.keys())
        self.alive = mp.Value('b', False)
        self.record_restart = mp.Value('b', False)
        self.record_stop = mp.Value('b', False)
        self.do_process = mp.Value('b', True)

        # per-camera calibration shared with the parent process:
        # one 3x3 intrinsic and one 4x4 extrinsic matrix per camera
        self.intrs = mp.Array('d', [0.0] * 9 * self.num_cam)
        self.extrs = mp.Array('d', [0.0] * 16 * self.num_cam)

    def log(self, msg):
        if self.verbose:
            print(f"\033[92m{self.name}: {msg}\033[0m")

    @property
    def can_record(self):
        return self.record_fps != 0

    def update_intrinsics(self, intrs):
        self.intrs[:] = intrs.flatten()

    def update_extrinsics(self, extrs):
        self.extrs[:] = extrs.flatten()

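    # Illustrative calibration push from the parent process (hypothetical
    # variable names), assuming `intrs` is a (num_cam, 3, 3) array and `extrs`
    # a (num_cam, 4, 4) array of world-to-camera transforms:
    #     perception.update_intrinsics(intrs)
    #     perception.update_extrinsics(extrs)
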
    def run(self):
        # keep BLAS/OpenCV single-threaded inside the child process
        threadpool_limits(1)
        cv2.setNumThreads(1)

        realsense = self.realsense

        capture_fps = self.capture_fps
        record_fps = self.record_fps
        record_time = self.record_time

        cameras_output = None
        recording_frame = float("inf")
        record_start_frame = 0
        is_recording = False
        timestamps_f = None

        # build the models here so they live in this process, not the parent
        checkpoint = str(root.parent / "weights/sam2/sam2.1_hiera_large.pt")
        model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        model_id = "IDEA-Research/grounding-dino-tiny"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        processor = AutoProcessor.from_pretrained(model_id)
        grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
        image_predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint, device=device))

        process_func = partial(
            segment_process_func,
            text_prompts=self.text_prompts,
            processor=processor,
            grounding_model=grounding_model,
            image_predictor=image_predictor,
            bbox=self.bbox,
            device=device,
            show_annotation=self.show_annotation,
        )

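        # Main loop: grab synchronized frames, run detection + segmentation,
        # and publish the latest result through a size-1 queue (drop-if-full).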
        while self.alive.value:
            try:
                if not self.do_process.value:
                    # processing paused: drain the queue and idle
                    if not self.perception_q.empty():
                        self.perception_q.get()
                    time.sleep(1)
                    continue
                cameras_output = realsense.get(out=cameras_output)
                get_time = time.time()
                timestamps = [cameras_output[i]['timestamp'].item() for i in range(self.num_cam)]
                if is_recording and not all(abs(timestamps[i] - timestamps[i + 1]) < 0.05 for i in range(self.num_cam - 1)):
                    print(f"Captured at different timestamps: {[f'{x:.2f}' for x in timestamps]}")

                process_start_time = get_time

                intrs = np.frombuffer(self.intrs.get_obj()).reshape((self.num_cam, 3, 3))
                extrs = np.frombuffer(self.extrs.get_obj()).reshape((self.num_cam, 4, 4))

                if intrs.sum() == 0 or extrs.sum() == 0:
                    print("No camera intrinsics and extrinsics provided")
                    time.sleep(1)
                    continue

                process_out = process_func(cameras_output, intrs, extrs)
                self.log(f"process time: {time.time() - process_start_time}")

                if not self.perception_q.full():
                    self.perception_q.put(process_out)

            except BaseException:
                print("Perception error:")
                traceback.print_exc()
                break

        if self.can_record:
            if timestamps_f is not None and not timestamps_f.closed:
                timestamps_f.close()
        self.stop()
        print("Perception process stopped")

    def start(self):
        self.alive.value = True
        super().start()

    def stop(self):
        self.alive.value = False
        self.perception_q.close()

    def set_record_start(self):
        if self.record_fps == 0:
            print("record disabled because record_fps is 0")
            assert not self.record_restart.value
        else:
            self.record_restart.value = True
            print("record restart cmd received")

    def set_record_stop(self):
        if self.record_fps == 0:
            print("record disabled because record_fps is 0")
            assert not self.record_stop.value
        else:
            self.record_stop.value = True
            print("record stop cmd received")
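

# Minimal usage sketch (hypothetical): the MultiRealsense constructor
# arguments and the bbox values below are assumptions for illustration,
# not the verified API of this repo's camera module.
if __name__ == "__main__":
    with MultiRealsense(capture_fps=30, record_fps=0) as realsense:
        perception = SegmentPerception(
            realsense=realsense,
            capture_fps=30,
            record_fps=0,  # exp_name is None, so recording must be disabled
            record_time=0,
            bbox=np.array([[-0.5, 0.5], [-0.5, 0.5]]),
            text_prompts="white cotton rope.",
        )
        perception.start()
        # calibration must be pushed before any frame is processed, e.g.:
        # perception.update_intrinsics(intrs)  # (num_cam, 3, 3)
        # perception.update_extrinsics(extrs)  # (num_cam, 4, 4)
        try:
            while True:
                if not perception.perception_q.empty():
                    out = perception.perception_q.get()
                    print("points per camera:", [p.shape[0] for p in out["pts"]])
                time.sleep(0.1)
        except KeyboardInterrupt:
            perception.stop()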