File size: 7,781 Bytes
f96995c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 |
from typing import List, Tuple, Optional, Dict
import math
import numpy as np
def get_accumulate_timestamp_idxs(
        timestamps: List[float],
        start_time: float,
        dt: float,
        eps: float = 1e-5,
        next_global_idx: Optional[int] = 0,
        allow_negative=False
) -> Tuple[List[int], List[int], int]:
    """Map sorted timestamps onto a fixed grid of dt-wide slots.

    Slot k covers [start_time + k * dt, start_time + (k + 1) * dt). Slots are
    filled in order, starting at ``next_global_idx``, and each slot takes the
    first timestamp that reaches it.

    If a timestamp lands several slots ahead (e.g. because frames were
    dropped), every slot skipped over is filled with that same timestamp. One
    timestamp may therefore be chosen multiple times. A timestamp that lands
    in an already-filled slot is ignored. ``timestamps`` is assumed sorted.

    Args:
        timestamps: sorted sample times.
        start_time: time of slot 0.
        dt: slot width.
        eps: offset (in units of dt) added before flooring, so a timestamp
            exactly at ``start_time + k * dt`` lands in slot k despite
            floating-point rounding.
        next_global_idx: first slot to fill. Normally start at 0 and pass back
            the returned value on the next call. Pass None to start at the slot
            of the first accepted timestamp; this allows overwriting slots that
            earlier calls already filled.
        allow_negative: when False, timestamps whose slot index is negative
            (i.e. before ``start_time``) are skipped.

    Returns:
        local_idxs: index into ``timestamps`` chosen for each filled slot.
        global_idxs: slot index filled by the matching entry of local_idxs.
        next_global_idx: the next unfilled slot, to pass to the next call.
            Remains None if None was passed and no timestamp was accepted.
    """
    chosen_local = []
    chosen_global = []
    cursor = next_global_idx
    for position, stamp in enumerate(timestamps):
        # eps shifts the ratio slightly upward so ts == start_time + k * dt is
        # always slot k rather than k - 1 after floating-point error.
        slot = math.floor((stamp - start_time) / dt + eps)
        if slot < 0 and not allow_negative:
            continue
        if cursor is None:
            cursor = slot
        missing = slot - cursor + 1
        if missing <= 0:
            # slot already filled by an earlier timestamp
            continue
        # fill every slot up to and including `slot` with this sample
        chosen_local.extend([position] * missing)
        chosen_global.extend([cursor + k for k in range(missing)])
        cursor += missing
    return chosen_local, chosen_global, cursor
def align_timestamps(
        timestamps: List[float],
        target_global_idxs: List[int],
        start_time: float,
        dt: float,
        eps: float = 1e-5):
    """Choose, for each target grid slot, the index of the timestamp to use.

    The grid is defined by ``start_time`` and ``dt`` (slot k starts at
    ``start_time + k * dt``). Slots are assigned with
    ``get_accumulate_timestamp_idxs`` starting at ``target_global_idxs[0]``.
    Slots past the end of the target range are dropped. Target slots that no
    timestamp reached are filled with the last timestamp. This includes the
    case where every timestamp falls before the first target slot.

    Args:
        timestamps: sorted sample times (non-empty).
        target_global_idxs: non-empty run of consecutive slot indices to fill;
            a list, tuple, or 1-D numpy array.
        start_time: time of slot 0.
        dt: slot width.
        eps: flooring tolerance forwarded to get_accumulate_timestamp_idxs.

    Returns:
        List with one index into ``timestamps`` per entry of
        ``target_global_idxs``.

    Raises:
        AssertionError: if ``target_global_idxs`` is empty or its entries are
            not consecutive.
        ValueError: if ``timestamps`` is empty (there is nothing to align).
    """
    if isinstance(target_global_idxs, np.ndarray):
        target_global_idxs = target_global_idxs.tolist()
    else:
        # copy into a list so the equality check below works for tuples etc.
        target_global_idxs = list(target_global_idxs)
    assert len(target_global_idxs) > 0
    if len(timestamps) == 0:
        # previously dropped into pdb and then crashed on global_idxs[-1]
        raise ValueError("align_timestamps requires at least one timestamp")

    n_targets = len(target_global_idxs)
    local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
        timestamps=timestamps,
        start_time=start_time,
        dt=dt,
        eps=eps,
        next_global_idx=target_global_idxs[0],
        allow_negative=True
    )
    # more slots available than requested: keep only the requested prefix
    local_idxs = local_idxs[:n_targets]
    global_idxs = global_idxs[:n_targets]

    # Fewer slots than requested: repeat the last timestamp. When global_idxs
    # is empty, every timestamp precedes the first target slot, so the last
    # timestamp is still the most recent sample available.
    last_local = len(timestamps) - 1
    next_slot = global_idxs[-1] + 1 if global_idxs else target_global_idxs[0]
    n_missing = n_targets - len(global_idxs)
    local_idxs.extend([last_local] * n_missing)
    global_idxs.extend(range(next_slot, next_slot + n_missing))

    # global_idxs is consecutive by construction, so this also verifies that
    # the targets were consecutive.
    assert global_idxs == target_global_idxs
    assert len(local_idxs) == len(global_idxs)
    return local_idxs
class TimestampObsAccumulator:
    """Append-only store of observations resampled onto a fixed dt grid.

    Each call to ``put`` continues filling grid slots from where the previous
    call stopped (``next_global_idx`` only moves forward). Slot selection is
    done by ``get_accumulate_timestamp_idxs``. Buffers are allocated lazily on
    the first ``put`` that fills at least one slot, and they grow geometrically
    as needed.
    """

    def __init__(self,
                 start_time: float,
                 dt: float,
                 eps: float = 1e-5):
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        # key -> array whose first axis is the grid slot
        self.obs_buffer = dict()
        # actual sample time stored in each slot; None until first allocation
        self.timestamp_buffer = None
        # number of slots filled so far (also the next slot to fill)
        self.next_global_idx = 0

    def __len__(self):
        return self.next_global_idx

    @property
    def data(self):
        """Filled portion of every observation buffer, keyed like ``put``'s data."""
        if self.timestamp_buffer is None:
            return dict()
        n_filled = len(self)
        return {key: buf[:n_filled] for key, buf in self.obs_buffer.items()}

    @property
    def actual_timestamps(self):
        """Original timestamp of the sample stored in each filled slot."""
        return (np.array([]) if self.timestamp_buffer is None
                else self.timestamp_buffer[:len(self)])

    @property
    def timestamps(self):
        """Nominal grid time (start_time + k * dt) of each filled slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    def put(self, data: Dict[str, np.ndarray], timestamps: np.ndarray):
        """Add a batch of samples.

        Args:
            data: mapping key -> array of shape (T, *) whose first axis matches
                ``timestamps``. The keys seen on the first allocating call fix
                which keys are stored afterwards.
            timestamps: array of length T of sample times, assumed sorted.
        """
        local_idxs, global_idxs, self.next_global_idx = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            next_global_idx=self.next_global_idx
        )
        if not global_idxs:
            return

        if self.timestamp_buffer is None:
            # first allocation: sized to this batch, grown below if needed
            self.obs_buffer = {key: np.zeros_like(arr) for key, arr in data.items()}
            self.timestamp_buffer = np.zeros((len(timestamps),), dtype=np.float64)

        required = global_idxs[-1] + 1
        capacity = len(self.timestamp_buffer)
        if required > capacity:
            # at least double to amortize reallocation; np.resize fills the new
            # tail with repeated copies, which the writes below overwrite
            grown = max(required, capacity * 2)
            for key, buf in list(self.obs_buffer.items()):
                self.obs_buffer[key] = np.resize(buf, (grown,) + buf.shape[1:])
            self.timestamp_buffer = np.resize(self.timestamp_buffer, (grown))

        for key, buf in self.obs_buffer.items():
            buf[global_idxs] = data[key][local_idxs]
        self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
class TimestampActionAccumulator:
    """Store of actions resampled onto a fixed dt grid, with overwriting.

    Different from the obs accumulator, each ``put`` starts at the grid slot of
    its own first timestamp. Actions for slots that were already written
    replace the previous values.
    """

    def __init__(self,
                 start_time: float,
                 dt: float,
                 eps: float = 1e-5):
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        # array whose first axis is the grid slot; None until first allocation
        self.action_buffer = None
        # actual issue time stored in each slot; None until first allocation
        self.timestamp_buffer = None
        # one past the highest slot ever written
        self.size = 0

    def __len__(self):
        return self.size

    @property
    def actions(self):
        """Actions for slots 0 .. len(self) - 1."""
        return (np.array([]) if self.action_buffer is None
                else self.action_buffer[:len(self)])

    @property
    def actual_timestamps(self):
        """Stored issue time for slots 0 .. len(self) - 1."""
        return (np.array([]) if self.timestamp_buffer is None
                else self.timestamp_buffer[:len(self)])

    @property
    def timestamps(self):
        """Nominal grid time (start_time + k * dt) for slots 0 .. len(self) - 1."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    def put(self, actions: np.ndarray, timestamps: np.ndarray):
        """Write a batch of actions, overwriting any slots they cover.

        Note: timestamps is the time when the action will be issued,
        not when the action will be completed (target_timestamp).

        NOTE(review): if a batch starts beyond the current size, the slots in
        between are never written. They keep zeros or the repeated copies
        np.resize produces, yet they still count toward len(self).
        """
        local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            # None -> start at this batch's first slot, allowing overwrites
            next_global_idx=None
        )
        if not global_idxs:
            return

        if self.timestamp_buffer is None:
            # first allocation: sized to this batch, grown below if needed
            self.action_buffer = np.zeros_like(actions)
            self.timestamp_buffer = np.zeros((len(actions),), dtype=np.float64)

        required = global_idxs[-1] + 1
        capacity = len(self.timestamp_buffer)
        if required > capacity:
            # at least double to amortize reallocation
            grown = max(required, 2 * capacity)
            self.action_buffer = np.resize(
                self.action_buffer, (grown,) + self.action_buffer.shape[1:])
            self.timestamp_buffer = np.resize(self.timestamp_buffer, (grown,))

        # may replace previously written slots (intended)
        self.action_buffer[global_idxs] = actions[local_idxs]
        self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
        self.size = max(self.size, required)
|