from pathlib import Path
import time
import traceback
from functools import partial

import numpy as np
import torch
import cv2
from threadpoolctl import threadpool_limits
import multiprocess as mp
from PIL import Image
import supervision as sv

from pgnd.utils import get_root
root: Path = get_root(__file__)

from camera.multi_realsense import MultiRealsense
from camera.single_realsense import SingleRealsense

from transformers import AutoProcessor, AutoModelForZeroShotObjectDetection
from sam2.build_sam import build_sam2
from sam2.sam2_image_predictor import SAM2ImagePredictor

from utils.pcd_utils import depth2fgpcd


def get_mask_raw(depth, intr, extr, bbox, depth_threshold=(0.0, 2.0)):
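    """Build a workspace mask: keep pixels whose depth lies inside depth_threshold
    and whose world-frame (x, y) position, obtained via the inverse extrinsics,
    falls inside bbox (the z axis is intentionally unconstrained)."""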
    points = depth2fgpcd(depth, intr).reshape(-1, 3)
    mask = np.logical_and((depth > depth_threshold[0]), (depth < depth_threshold[1]))  # (H, W)

    # transform camera-frame points to the world frame via inverse extrinsics
    points_h = np.concatenate([points, np.ones((points.shape[0], 1), dtype=np.float32)], axis=1)
    points = (np.linalg.inv(extr) @ points_h.T).T[:, :3]  # (N, 3)
    mask_bbox = np.logical_and(
        np.logical_and(points[:, 0] > bbox[0][0], points[:, 0] < bbox[0][1]),
        np.logical_and(points[:, 1] > bbox[1][0], points[:, 1] < bbox[1][1])
    )  # does not include z axis
    mask_bbox = mask_bbox.reshape(depth.shape[0], depth.shape[1])
    mask = np.logical_and(mask, mask_bbox)
    return mask


def segment_process_func(cameras_output, intrs, extrs, text_prompts, processor,
                         grounding_model, image_predictor, bbox, device, show_annotation=True):
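    """Segment one synchronized multi-camera frame: Grounding DINO proposes boxes
    for the text prompt, the workspace depth mask filters them, SAM2 converts the
    surviving boxes to masks, and the masked depth pixels are back-projected into
    world-frame point clouds (one array per camera)."""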
    colors_list = []
    depths_list = []
    pts_list = []
    for ck, cv in cameras_output.items():

        image = cv["color"].copy()
        depth = cv["depth"].copy() / 1000.0  # RealSense depth is in millimeters; convert to meters
        image = Image.fromarray(image)

        # grounded detection: Grounding DINO proposes boxes matching the text prompt
        inputs = processor(images=image, text=text_prompts, return_tensors="pt").to(device)
        with torch.no_grad():
            outputs = grounding_model(**inputs)
        results = processor.post_process_grounded_object_detection(
            outputs,
            inputs.input_ids,
            box_threshold=0.325,
            text_threshold=0.3,
            target_sizes=[image.size[::-1]]
        )
        input_boxes = results[0]["boxes"].cpu().numpy()
        objects = results[0]["labels"]

        if intrs is None or extrs is None:
            print("No camera intrinsics or extrinsics provided")
            return {
                "color": [],
                "depth": [],
                "pts": [],
            }
        depth_mask = get_mask_raw(depth, intrs[ck], extrs[ck], bbox)

        multi_objs = False
        if len(objects) > 1:
            # filter detections: drop empty labels, implausibly large boxes
            # (area above 500 * 400 px), and boxes that do not overlap the
            # workspace depth mask
            objects_masked = []
            input_boxes_masked = []
            for i, obj in enumerate(objects):
                if obj == '':
                    continue
                box = input_boxes[i].astype(int)
                if (box[3] - box[1]) * (box[2] - box[0]) > 500 * 400:
                    continue
                depth_mask_box = depth_mask[box[1]:box[3], box[0]:box[2]]
                if depth_mask_box.sum() > 0:
                    objects_masked.append(obj)
                    input_boxes_masked.append(box)
            objects = objects_masked
            input_boxes = np.array(input_boxes_masked)
            if len(objects) == 0:
                print("No objects detected")
                return {
                    "color": [],
                    "depth": [],
                    "pts": [],
                }
            elif len(objects) > 1:
                multi_objs = True

        image_predictor.set_image(np.array(image.convert("RGB")))
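        # one mask per box since multimask_output=False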
        masks, scores, logits = image_predictor.predict(
            point_coords=None,
            point_labels=None,
            box=input_boxes,
            multimask_output=False,
        )
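        # masks come back (1, H, W) for a single box and (n, 1, H, W) for a batch
        # of boxes; normalize to (n, H, W)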
        if masks.ndim == 3:
            pass
        elif masks.ndim == 4:
            assert multi_objs
            masks = masks.squeeze(1)
        masks = masks.astype(bool)

        ID_TO_OBJECTS = {i: obj for i, obj in enumerate(objects, start=1)}
        object_ids = np.arange(1, len(objects) + 1)

        detections = sv.Detections(
            xyxy=sv.mask_to_xyxy(masks),  # (n, 4)
            mask=masks,  # (n, h, w)
            class_id=np.array(object_ids, dtype=np.int32),
        )
        box_annotator = sv.BoxAnnotator()
        if show_annotation:
            annotated_frame = box_annotator.annotate(scene=np.array(image).astype(np.uint8), detections=detections)
            label_annotator = sv.LabelAnnotator()
            annotated_frame = label_annotator.annotate(annotated_frame, detections=detections, labels=[ID_TO_OBJECTS[i] for i in object_ids])
            mask_annotator = sv.MaskAnnotator()
            annotated_frame = mask_annotator.annotate(scene=annotated_frame, detections=detections)
            colors_list.append(annotated_frame)
        else:
            colors_list.append(np.array(image))

        depths_list.append(cv["depth"].copy())

        # fuse all instance masks, intersect with the workspace mask, and
        # back-project the surviving depth pixels into a world-frame point cloud
        masks = np.logical_or.reduce(masks, axis=0, keepdims=True)
        masks = np.logical_and(masks, depth_mask)
        masks = masks.reshape(-1)
        assert masks.shape[0] == depth.shape[0] * depth.shape[1]
        points = depth2fgpcd(depth, intrs[ck]).reshape(-1, 3)
        points_h = np.concatenate([points, np.ones((points.shape[0], 1), dtype=np.float32)], axis=1)
        points = (np.linalg.inv(extrs[ck]) @ points_h.T).T[:, :3]  # (N, 3)
        points = points[masks]
        pts_list.append(points)

    return {
        "color": colors_list,
        "depth": depths_list,
        "pts": pts_list,
    }


class SegmentPerception(mp.Process):
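    """Background process that segments text-prompted objects from RealSense
    streams and publishes annotated colors, raw depths, and world-frame point
    clouds through a size-1 queue."""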

    def __init__(
        self,
        realsense: MultiRealsense | SingleRealsense, 
        capture_fps, 
        record_fps, 
        record_time,
        exp_name=None,
        bbox=None,
        data_dir="data",
        text_prompts="white cotton rope.",
        show_annotation=True,
        device=None,
        verbose=False,
    ):
        super().__init__()
        self.verbose = verbose

        self.capture_fps = capture_fps
        self.record_fps = record_fps
        self.record_time = record_time
        self.exp_name = exp_name
        self.data_dir = data_dir
        self.bbox = bbox

        self.text_prompts = text_prompts
        self.show_annotation = show_annotation

        if self.exp_name is None:
            assert self.record_fps == 0

        self.realsense = realsense
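        # size-1 queue so consumers always read the freshest perception result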
        self.perception_q = mp.Queue(maxsize=1)

        self.num_cam = len(realsense.cameras.keys())
        self.alive = mp.Value('b', False)
        self.record_restart = mp.Value('b', False)
        self.record_stop = mp.Value('b', False)
        self.do_process = mp.Value('b', True)

        # flat shared buffers: one 3x3 intrinsic and one 4x4 extrinsic per camera,
        # filled via update_intrinsics / update_extrinsics before processing starts
        self.intrs = mp.Array('d', [0.0] * 9 * self.num_cam)
        self.extrs = mp.Array('d', [0.0] * 16 * self.num_cam)

    def log(self, msg):
        if self.verbose:
            print(f"\033[92m{self.name}: {msg}\033[0m")

    @property
    def can_record(self):
        return self.record_fps != 0
    
    def update_intrinsics(self, intrs):
        self.intrs[:] = intrs.flatten()
    
    def update_extrinsics(self, extrs):
        self.extrs[:] = extrs.flatten()

    def run(self):
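        """Worker loop: load models, then repeatedly grab synchronized frames,
        run segmentation, and publish results until stopped."""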
        # limit threads
        threadpool_limits(1)
        cv2.setNumThreads(1)

        realsense = self.realsense

        capture_fps = self.capture_fps
        record_fps = self.record_fps
        record_time = self.record_time

        cameras_output = None
        recording_frame = float("inf")  # frame index within the current recording (record fps)
        record_start_frame = 0  # global frame index since process start (capture fps)
        is_recording = False  # recording state flag
        timestamps_f = None

        # load Grounding DINO from the HF hub and SAM2 from local weights
        checkpoint = str(root.parent / "weights/sam2/sam2.1_hiera_large.pt")
        model_cfg = "configs/sam2.1/sam2.1_hiera_l.yaml"
        model_id = "IDEA-Research/grounding-dino-tiny"
        device = "cuda" if torch.cuda.is_available() else "cpu"
        processor = AutoProcessor.from_pretrained(model_id)
        grounding_model = AutoModelForZeroShotObjectDetection.from_pretrained(model_id).to(device)
        image_predictor = SAM2ImagePredictor(build_sam2(model_cfg, checkpoint))
        
        process_func = partial(
            segment_process_func,
            text_prompts=self.text_prompts,
            processor=processor,
            grounding_model=grounding_model,
            image_predictor=image_predictor,
            bbox=self.bbox,
            device=device,
            show_annotation=self.show_annotation,
        )

        while self.alive.value:
            try: 
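                # processing paused: drop any stale result and idle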
                if not self.do_process.value:
                    if not self.perception_q.empty():
                        self.perception_q.get()
                    time.sleep(1)
                    continue
                cameras_output = realsense.get(out=cameras_output)
                get_time = time.time()
                timestamps = [cameras_output[i]['timestamp'].item() for i in range(self.num_cam)]  # type: ignore
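                # warn when cameras captured noticeably out of sync (> 50 ms apart)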
                if is_recording and not all([abs(timestamps[i] - timestamps[i+1]) < 0.05 for i in range(self.num_cam - 1)]):
                    print(f"Captured at different timestamps: {[f'{x:.2f}' for x in timestamps]}")

                # treat captured time and record time as the same
                process_start_time = get_time

                # latest calibration written by the parent; all zeros means not yet set
                intrs = np.frombuffer(self.intrs.get_obj()).reshape((self.num_cam, 3, 3))
                extrs = np.frombuffer(self.extrs.get_obj()).reshape((self.num_cam, 4, 4))

                if intrs.sum() == 0 or extrs.sum() == 0:
                    print("Camera intrinsics/extrinsics not yet provided")
                    time.sleep(1)
                    continue

                process_out = process_func(cameras_output, intrs, extrs)
                self.log(f"process time: {time.time() - process_start_time}")
            
                if not self.perception_q.full():
                    self.perception_q.put(process_out)

            except BaseException as e:
                print("Perception error:", e)
                traceback.print_exc()
                break

        if self.can_record:
            # close the timestamps file if a recording was left open
            if timestamps_f is not None and not timestamps_f.closed:
                timestamps_f.close()
        self.stop()
        print("Perception process stopped")


    def start(self):
        self.alive.value = True
        super().start()

    def stop(self):
        self.alive.value = False
        self.perception_q.close()
    
    def set_record_start(self):
        if self.record_fps == 0:
            print("record disabled because record_fps is 0")
            assert not self.record_restart.value
        else:
            self.record_restart.value = True
            print("record restart cmd received")

    def set_record_stop(self):
        if self.record_fps == 0:
            print("record disabled because record_fps is 0")
            assert not self.record_stop.value
        else:
            self.record_stop.value = True
            print("record stop cmd received")