File size: 7,948 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 |
# mypy: allow-untyped-defs
from typing import Optional
import torch
class SobolEngine:
    r"""
    The :class:`torch.quasirandom.SobolEngine` is an engine for generating
    (scrambled) Sobol sequences. Sobol sequences are an example of low
    discrepancy quasi-random sequences.

    This implementation of an engine for Sobol sequences is capable of
    sampling sequences up to a maximum dimension of 21201. It uses direction
    numbers from https://web.maths.unsw.edu.au/~fkuo/sobol/ obtained using the
    search criterion D(6) up to the dimension 21201. This is the recommended
    choice by the authors.

    References:
      - Art B. Owen. Scrambling Sobol and Niederreiter-Xing points.
        Journal of Complexity, 14(4):466-489, December 1998.

      - I. M. Sobol. The distribution of points in a cube and the accurate
        evaluation of integrals.
        Zh. Vychisl. Mat. i Mat. Phys., 7:784-802, 1967.

    Args:
        dimension (Int): The dimensionality of the sequence to be drawn
        scramble (bool, optional): Setting this to ``True`` will produce
                                   scrambled Sobol sequences. Scrambling is
                                   capable of producing better Sobol
                                   sequences. Default: ``False``.
        seed (Int, optional): This is the seed for the scrambling. The seed
                              of the random number generator is set to this,
                              if specified. Otherwise, it uses a random seed.
                              Default: ``None``

    Examples::

        >>> # xdoctest: +SKIP("unseeded random state")
        >>> soboleng = torch.quasirandom.SobolEngine(dimension=5)
        >>> soboleng.draw(3)
        tensor([[0.0000, 0.0000, 0.0000, 0.0000, 0.0000],
                [0.5000, 0.5000, 0.5000, 0.5000, 0.5000],
                [0.7500, 0.2500, 0.2500, 0.2500, 0.7500]])
    """

    # Bits per coordinate: integer states live in [0, 2**MAXBIT) and are scaled
    # by 2**-MAXBIT to obtain points in the unit cube.
    MAXBIT = 30
    # Largest dimensionality accepted by __init__.
    MAXDIM = 21201

    def __init__(self, dimension, scramble=False, seed=None):
        if dimension > self.MAXDIM or dimension < 1:
            raise ValueError(
                "Supported range of dimensionality "
                f"for SobolEngine is [1, {self.MAXDIM}]"
            )

        self.seed = seed
        self.scramble = scramble
        self.dimension = dimension

        cpu = torch.device("cpu")

        # Per-dimension integer state of shape (dimension, MAXBIT), filled by the
        # native op (presumably the direction numbers described above).
        self.sobolstate = torch.zeros(
            dimension, self.MAXBIT, device=cpu, dtype=torch.long
        )
        torch._sobol_engine_initialize_state_(self.sobolstate, self.dimension)

        if not self.scramble:
            self.shift = torch.zeros(self.dimension, device=cpu, dtype=torch.long)
        else:
            # Sets self.shift and scrambles self.sobolstate.
            self._scramble()

        # Current integer state of the sequence; starts at the shift.
        self.quasi = self.shift.clone(memory_format=torch.contiguous_format)
        # The first point is cached in float64: shift values use MAXBIT (30) bits,
        # which float32 cannot represent exactly, whereas dividing by a power of
        # two is exact in float64. draw() then rounds once to the requested dtype.
        self._first_point = (
            self.quasi.to(torch.float64) / 2**self.MAXBIT
        ).reshape(1, -1)
        self.num_generated = 0

    def draw(
        self,
        n: int = 1,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`n` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(n, dimension)`.

        Args:
            n (Int, optional): The length of sequence of points to draw.
                               Default: 1
            out (Tensor, optional): The output tensor
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                                                    returned tensor.
                                                    Default: ``None``
        """
        if dtype is None:
            dtype = torch.get_default_dtype()

        if self.num_generated == 0:
            # At the start of the sequence the first point is served from the
            # cached _first_point and the native op only draws the remaining n - 1.
            if n == 0:
                # Nothing requested; asking the native op for n - 1 = -1 points
                # would fail, so return an empty (0, dimension) tensor instead.
                result = self._first_point.new_empty(
                    (0, self.dimension), dtype=dtype
                )
            elif n == 1:
                # Copy: ``.to`` returns the same tensor when the dtype already
                # matches, and callers must not be able to mutate the cache.
                result = self._first_point.to(dtype, copy=True)
            else:
                result, self.quasi = torch._sobol_engine_draw(
                    self.quasi,
                    n - 1,
                    self.sobolstate,
                    self.dimension,
                    self.num_generated,
                    dtype=dtype,
                )
                # torch.cat allocates a new tensor, so no extra copy is needed.
                result = torch.cat((self._first_point.to(dtype), result), dim=-2)
        else:
            # Past the first point, the native op is given num_generated - 1,
            # mirroring the n - 1 offset used above.
            result, self.quasi = torch._sobol_engine_draw(
                self.quasi,
                n,
                self.sobolstate,
                self.dimension,
                self.num_generated - 1,
                dtype=dtype,
            )

        self.num_generated += n

        if out is not None:
            out.resize_as_(result).copy_(result)
            return out

        return result

    def draw_base2(
        self,
        m: int,
        out: Optional[torch.Tensor] = None,
        dtype: Optional[torch.dtype] = None,
    ) -> torch.Tensor:
        r"""
        Function to draw a sequence of :attr:`2**m` points from a Sobol sequence.
        Note that the samples are dependent on the previous samples. The size
        of the result is :math:`(2**m, dimension)`.

        Args:
            m (Int): The (base2) exponent of the number of points to draw.
            out (Tensor, optional): The output tensor
            dtype (:class:`torch.dtype`, optional): the desired data type of the
                                                    returned tensor.
                                                    Default: ``None``

        Raises:
            ValueError: if the total number of points generated so far plus
                ``2**m`` is not a power of two.
        """
        n = 2**m
        total_n = self.num_generated + n
        # x & (x - 1) == 0 holds exactly when x is a power of two (for x >= 1).
        if not (total_n & (total_n - 1) == 0):
            raise ValueError(
                "The balance properties of Sobol' points require "
                f"n to be a power of 2. {self.num_generated} points have been "
                f"previously generated, then: n={self.num_generated}+2**{m}={total_n}. "
                "If you still want to do this, please use "
                "'SobolEngine.draw()' instead."
            )
        return self.draw(n=n, out=out, dtype=dtype)

    def reset(self):
        r"""
        Function to reset the ``SobolEngine`` to base state.

        Restores the integer state to the shift and clears the counter, so the
        next draw starts again from the first point. Returns ``self``.
        """
        self.quasi.copy_(self.shift)
        self.num_generated = 0
        return self

    def fast_forward(self, n):
        r"""
        Function to fast-forward the state of the ``SobolEngine`` by
        :attr:`n` steps. This is equivalent to drawing :attr:`n` samples
        without using the samples.

        Args:
            n (Int): The number of steps to fast-forward by.

        Returns ``self``.
        """
        # Same offsets as draw(): the first point needs no native step, so a
        # fresh engine advances the integer state by n - 1.
        if self.num_generated == 0:
            torch._sobol_engine_ff_(
                self.quasi, n - 1, self.sobolstate, self.dimension, self.num_generated
            )
        else:
            torch._sobol_engine_ff_(
                self.quasi, n, self.sobolstate, self.dimension, self.num_generated - 1
            )
        self.num_generated += n
        return self

    def _scramble(self):
        # Build a random shift and random lower-triangular binary matrices,
        # using a seeded generator when self.seed is set (global RNG otherwise).
        g: Optional[torch.Generator] = None
        if self.seed is not None:
            g = torch.Generator()
            g.manual_seed(self.seed)

        cpu = torch.device("cpu")

        # Generate shift vector: MAXBIT random bits per dimension, packed into one
        # integer each via a dot product with powers of two.
        shift_ints = torch.randint(
            2, (self.dimension, self.MAXBIT), device=cpu, generator=g
        )
        self.shift = torch.mv(
            shift_ints, torch.pow(2, torch.arange(0, self.MAXBIT, device=cpu))
        )

        # Generate lower triangular matrices (stacked across dimensions)
        ltm_dims = (self.dimension, self.MAXBIT, self.MAXBIT)
        ltm = torch.randint(2, ltm_dims, device=cpu, generator=g).tril()

        # The return value is unused: the native op updates self.sobolstate.
        torch._sobol_engine_scramble_(self.sobolstate, ltm, self.dimension)

    def __repr__(self):
        fmt_string = [f"dimension={self.dimension}"]
        if self.scramble:
            fmt_string += ["scramble=True"]
        if self.seed is not None:
            fmt_string += [f"seed={self.seed}"]
        return self.__class__.__name__ + "(" + ", ".join(fmt_string) + ")"
|