File size: 2,470 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
"""sample."""

import math


class UniformSampleAccumulator:
    """Keep a bounded, evenly strided subset of a stream of values.

    Values are offered one at a time with :meth:`add`.  Storage is a fixed
    ring of ``self._buckets`` pre-allocated buckets, each holding at most
    ``self._samples`` values, so memory does not grow with the stream length.

    Only every ``2 ** self._shift``-th value offered is recorded.  Whenever
    the (scaled) count reaches ``2 ** self._buckets``, the oldest bucket is
    recycled as the newest one and the shift grows by one, halving the
    recording rate.  Older buckets therefore hold values recorded at a finer
    stride; :meth:`get` thins them so the returned values share one stride.
    """

    def __init__(self, min_samples=None):
        """Create an accumulator.

        Args:
            min_samples: Requested number of samples (default 64 when falsy).
                Rounded up to the next power of two.

        Raises:
            ValueError: If ``min_samples`` is less than 1.
        """
        self._samples = self._ceil_pow2(min_samples or 64)
        # target oversample by factor of 2
        self._samples2 = self._samples * 2
        # max size of each buffer (equals self._samples)
        self._max = self._samples2 // 2
        # Only counts that are multiples of 2**_shift get recorded; _mask is
        # 2**_shift - 1, so `count & _mask` is non-zero for skipped counts.
        self._shift = 0
        self._mask = (1 << self._shift) - 1
        # One bucket per power of two up to _samples2 (exact integer log2;
        # _samples2 is a power of two).
        self._buckets = self._samples2.bit_length() - 1
        # NOTE(review): the next two attributes are not used in this class.
        self._buckets_bits = self._buckets.bit_length() - 1
        self._buckets_mask = (1 << (self._buckets_bits + 1)) - 1
        # Physical index of the oldest (logical position 0) bucket.
        self._buckets_index = 0
        # Number of values currently stored in each physical bucket.
        self._index = [0] * self._buckets
        self._count = 0

        # pre-allocate buckets (distinct lists)
        self._bucket = [[0] * self._max for _ in range(self._buckets)]
        # Exact floor(log2(i)) lookup table for 0 <= i <= 2**_buckets;
        # entry 0 is a placeholder (never looked up by add()).
        self._log2 = [0] + [
            i.bit_length() - 1 for i in range(1, 2**self._buckets + 1)
        ]

    @staticmethod
    def _ceil_pow2(n):
        """Return the smallest power of two >= ``n`` using exact int math.

        Float ``math.log(x, 2)`` is not exact (e.g. it yields
        29.000000000000004 for 2**29), which would double the size.
        """
        m = math.ceil(n)
        if m < 1:
            raise ValueError(f"min_samples must be >= 1, got {n!r}")
        return 1 << (m - 1).bit_length()

    def _show(self):
        """Print every bucket's stored values, oldest bucket first (debug aid)."""
        print("=" * 20)  # noqa: T201
        for pos in range(self._buckets):
            b = (pos + self._buckets_index) % self._buckets
            vals = self._bucket[b][: self._index[b]]
            print(f"{b}: {vals}")  # noqa: T201

    def add(self, val):
        """Offer ``val`` to the accumulator.

        The value is recorded only if the running count is a multiple of the
        current stride ``2 ** self._shift``; otherwise it is dropped.
        """
        self._count += 1
        cnt = self._count
        if cnt & self._mask:
            return
        # Logical bucket = floor(log2(cnt >> shift)).
        b = self._log2[cnt >> self._shift]
        if b >= self._buckets:
            # The scaled count hit 2**buckets: recycle the oldest bucket as the
            # newest one and halve the recording rate from now on.
            self._index[self._buckets_index] = 0
            self._buckets_index = (self._buckets_index + 1) % self._buckets
            self._shift += 1
            self._mask = (self._mask << 1) | 1
            # b == self._buckets here; this value starts the newest bucket.
            b = self._buckets - 1
        b = (b + self._buckets_index) % self._buckets
        self._bucket[b][self._index[b]] = val
        self._index[b] += 1

    def get(self):
        """Return the recorded values, oldest first.

        Returns:
            tuple: The values thinned to a common stride.  If that thinned
            selection has fewer than ``self._samples`` values, every stored
            value is returned instead.
        """
        full = []
        sampled = []
        for pos in range(self._buckets):
            b = (pos + self._buckets_index) % self._buckets
            items = self._bucket[b][: self._index[b]]
            # A full bucket at logical position `pos` was filled at a stride
            # finer by this factor; a partial newest bucket gives 0 (keep all).
            step = len(items) // (1 << pos)
            sampled.extend(items[::step] if step else items)
            full.extend(items)
        if len(sampled) < self._samples:
            return tuple(full)
        return tuple(sampled)