kywind
update
f96995c
from typing import List, Tuple, Optional, Dict
import math
import numpy as np
def get_accumulate_timestamp_idxs(
        timestamps: List[float],
        start_time: float,
        dt: float,
        eps: float = 1e-5,
        next_global_idx: Optional[int] = 0,
        allow_negative=False
        ) -> Tuple[List[int], List[int], int]:
    """
    Map sorted timestamps onto a regular grid of dt-wide windows that
    begins at start_time, keeping the first timestamp seen for each window.

    When one or more windows are skipped (e.g. dropped frames), the first
    timestamp that lands at or past them is reused to fill every skipped
    window, so a single local index can appear several times in a row.

    next_global_idx is the first window that is still unfilled. Pass 0 for
    the first call and feed the returned value back in on later calls.
    Pass None to start from whichever window the first accepted timestamp
    falls into (useful when earlier values are meant to be overwritten).

    Timestamps that fall before start_time are ignored unless
    allow_negative is True.

    Returns:
    local_idxs: positions in the given timestamps that were selected
    global_idxs: the window index assigned to each selected position
    next_global_idx: first unfilled window after this call; stays None
        if None was passed in and no timestamp was accepted.
    """
    picked_local = []
    picked_global = []
    for position, stamp in enumerate(timestamps):
        # eps nudges the quotient upward so that a stamp sitting exactly on
        # start_time + k * dt lands in window k despite rounding error
        window = math.floor((stamp - start_time) / dt + eps)
        if window < 0 and not allow_negative:
            continue
        if next_global_idx is None:
            next_global_idx = window

        # number of windows from the first unfilled one up to this one
        n_fill = window - next_global_idx + 1
        if n_fill > 0:
            picked_local.extend([position] * n_fill)
            picked_global.extend(
                next_global_idx + offset for offset in range(n_fill))
            next_global_idx += n_fill
    return picked_local, picked_global, next_global_idx
def align_timestamps(
        timestamps: List[float],
        target_global_idxs: List[int],
        start_time: float,
        dt: float,
        eps: float = 1e-5):
    """
    For each entry of target_global_idxs, return the index into timestamps
    of the sample that should be used for that dt window.

    Windows are assigned by get_accumulate_timestamp_idxs, starting at
    target_global_idxs[0] and allowing negative window indices. If the
    timestamps cover more windows than requested the result is truncated;
    if they cover fewer, the remaining windows reuse the last timestamp.
    This includes the case where no timestamp reaches the first target
    window at all, in which every target reuses the last timestamp.

    target_global_idxs may be a list or a numpy array and must be a
    non-empty run of consecutive integers.

    Returns:
    local_idxs: one index into timestamps per target global index.

    Raises:
    AssertionError: target_global_idxs is empty or not consecutive.
    ValueError: timestamps is empty, so there is nothing to repeat.
    """
    if isinstance(target_global_idxs, np.ndarray):
        target_global_idxs = target_global_idxs.tolist()
    assert len(target_global_idxs) > 0

    local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
        timestamps=timestamps,
        start_time=start_time,
        dt=dt,
        eps=eps,
        next_global_idx=target_global_idxs[0],
        allow_negative=True
    )

    n_targets = len(target_global_idxs)
    if len(global_idxs) > n_targets:
        # if more steps available, truncate
        global_idxs = global_idxs[:n_targets]
        local_idxs = local_idxs[:n_targets]

    n_missing = n_targets - len(global_idxs)
    if n_missing > 0:
        if len(timestamps) == 0:
            raise ValueError(
                "align_timestamps requires at least one timestamp")
        # if missing, repeat the last available timestamp. When nothing was
        # matched at all, numbering starts from the first requested window
        # instead of reading global_idxs[-1] from an empty list.
        if len(global_idxs) > 0:
            first_missing = global_idxs[-1] + 1
        else:
            first_missing = target_global_idxs[0]
        local_idxs.extend([len(timestamps) - 1] * n_missing)
        global_idxs.extend(range(first_missing, first_missing + n_missing))

    # the produced windows are consecutive, so this only holds when the
    # requested windows are consecutive as well
    assert global_idxs == target_global_idxs
    assert len(local_idxs) == len(global_idxs)
    return local_idxs
class TimestampObsAccumulator:
    """
    Collects observations onto a regular time grid that starts at
    start_time and advances by dt per slot.

    Slots are filled strictly forward: each call to put() continues from
    the slot after the last one written, so earlier slots are never
    rewritten.
    """

    def __init__(self,
            start_time: float,
            dt: float,
            eps: float = 1e-5):
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        # key -> array whose first axis is the slot index
        self.obs_buffer = dict()
        # measured timestamp stored for each slot; None until first write
        self.timestamp_buffer = None
        # first slot not yet written; doubles as the number of valid slots
        self.next_global_idx = 0

    def __len__(self):
        return self.next_global_idx

    @property
    def data(self):
        """Per-key buffers trimmed to the filled slots (empty dict before any write)."""
        if self.timestamp_buffer is None:
            return dict()
        return {
            name: buffer[:len(self)]
            for name, buffer in self.obs_buffer.items()
        }

    @property
    def actual_timestamps(self):
        """Timestamps of the samples actually stored in each filled slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.timestamp_buffer[:len(self)]

    @property
    def timestamps(self):
        """Nominal grid time, start_time + k * dt, of each filled slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    def put(self, data: Dict[str, np.ndarray], timestamps: np.ndarray):
        """
        Add a batch of observations.

        data:
            key: T,*
        timestamps: indexed with the same positions as each data array.
        """
        local_idxs, global_idxs, self.next_global_idx = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            next_global_idx=self.next_global_idx
        )
        if len(global_idxs) == 0:
            # no new slot was reached by this batch
            return

        if self.timestamp_buffer is None:
            # first allocation: buffers take the shape of the first batch
            self.obs_buffer = {
                name: np.zeros_like(array) for name, array in data.items()
            }
            self.timestamp_buffer = np.zeros(
                (len(timestamps),), dtype=np.float64)

        required = global_idxs[-1] + 1
        capacity = len(self.timestamp_buffer)
        if required > capacity:
            # grow to at least double the capacity, keeping existing
            # contents at the front of each buffer
            grown = max(required, capacity * 2)
            for name in list(self.obs_buffer.keys()):
                old = self.obs_buffer[name]
                self.obs_buffer[name] = np.resize(
                    old, (grown,) + old.shape[1:])
            self.timestamp_buffer = np.resize(
                self.timestamp_buffer, (grown,))

        # write data
        for name, buffer in self.obs_buffer.items():
            buffer[global_idxs] = data[name][local_idxs]
        self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
class TimestampActionAccumulator:
    """
    Collects actions onto a regular time grid that starts at start_time
    and advances by dt per slot.
    """

    def __init__(self,
            start_time: float,
            dt: float,
            eps: float = 1e-5):
        """
        Different from Obs accumulator, the action accumulator
        allows overwriting previous values.
        """
        self.start_time = start_time
        self.dt = dt
        self.eps = eps
        # both buffers stay None until the first successful put()
        self.action_buffer = None
        self.timestamp_buffer = None
        # one past the highest slot ever written
        self.size = 0

    def __len__(self):
        return self.size

    @property
    def actions(self):
        """Stored actions for slots 0 .. len(self) - 1."""
        if self.action_buffer is None:
            return np.array([])
        return self.action_buffer[:len(self)]

    @property
    def actual_timestamps(self):
        """Timestamps of the actions actually stored in each slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.timestamp_buffer[:len(self)]

    @property
    def timestamps(self):
        """Nominal grid time, start_time + k * dt, of each slot."""
        if self.timestamp_buffer is None:
            return np.array([])
        return self.start_time + np.arange(len(self)) * self.dt

    def put(self, actions: np.ndarray, timestamps: np.ndarray):
        """
        Note: timestamps is the time when the action will be issued,
        not when the action will be completed (target_timestamp)
        """
        local_idxs, global_idxs, _ = get_accumulate_timestamp_idxs(
            timestamps=timestamps,
            start_time=self.start_time,
            dt=self.dt,
            eps=self.eps,
            # None makes slot numbering restart at this batch's first
            # timestamp, which allows overwriting previous actions
            next_global_idx=None
        )
        if len(global_idxs) == 0:
            # every timestamp was rejected; nothing to store
            return

        if self.timestamp_buffer is None:
            # first allocation: buffers take the shape of the first batch
            self.action_buffer = np.zeros_like(actions)
            self.timestamp_buffer = np.zeros((len(actions),), dtype=np.float64)

        required = global_idxs[-1] + 1
        capacity = len(self.timestamp_buffer)
        if required > capacity:
            # grow to at least double the capacity, keeping existing
            # contents at the front of each buffer
            grown = max(required, capacity * 2)
            self.action_buffer = np.resize(
                self.action_buffer, (grown,) + self.action_buffer.shape[1:])
            self.timestamp_buffer = np.resize(self.timestamp_buffer, (grown,))

        # potentially rewrite old data (as expected)
        self.action_buffer[global_idxs] = actions[local_idxs]
        self.timestamp_buffer[global_idxs] = timestamps[local_idxs]
        self.size = max(self.size, required)