# Source snapshot: commit f96995c (file size 1,333 bytes)
import numpy as np
def sample_points_from_masks(masks, num_points, rng=None):
    """Sample random foreground points from each mask as absolute coordinates.

    Args:
        masks: np.ndarray with shape (n, h, w). A pixel counts as foreground
            when its value equals 1 (boolean ``True`` also qualifies).
        num_points: int, number of points to sample per mask. If a mask has
            fewer foreground pixels than this, pixels are sampled with
            replacement; otherwise without replacement.
        rng: optional ``np.random.Generator`` or ``np.random.RandomState``
            used for sampling. Defaults to the global ``np.random`` state,
            so ``np.random.seed`` keeps controlling the result.

    Returns:
        points: np.ndarray of dtype float32 with shape (n, num_points, 2)
            holding absolute (x, y) pixel coordinates. Rows belonging to
            masks with no foreground pixel are filled with NaN.

    Raises:
        ValueError: if ``masks`` is not 3-dimensional.
    """
    if rng is None:
        rng = np.random
    # Unpacking enforces the (n, h, w) rank; h and w themselves are unused.
    n, _, _ = masks.shape
    # Preallocate with NaN so an empty mask still yields a well-formed
    # (num_points, 2) row. Appending a ragged empty array cannot be stacked.
    points = np.full((n, num_points, 2), np.nan, dtype=np.float32)
    for i in range(n):
        # np.argwhere yields (row, col) == (y, x); flip the columns to (x, y).
        coords = np.argwhere(masks[i] == 1)[:, ::-1]
        if len(coords) == 0:
            continue  # leave this row as NaN
        # Resample with replacement only when there are not enough pixels.
        chosen = rng.choice(len(coords), num_points, replace=len(coords) < num_points)
        points[i] = coords[chosen]
    return points