"""Wrapper to mimic (parts of) np.random API surface. |
|
|
|
NumPy has strict guarantees on reproducibility etc; here we don't give any. |
|
|
|
Q: default dtype is float64 in numpy |
|
|
|
""" |
|
from __future__ import annotations |
|
|
|
import functools
from math import sqrt
from typing import Optional

import torch

from . import _dtypes_impl, _util
from ._normalizations import array_or_scalar, ArrayLike, normalizer


__all__ = [
    "seed",
    "random_sample",
    "sample",
    "random",
    "rand",
    "randn",
    "normal",
    "choice",
    "randint",
    "shuffle",
    "uniform",
]


def use_numpy_random():
    # imported here, not at module level, so torch._dynamo is only pulled in
    # when the flag is actually queried
    import torch._dynamo.config as config

    return config.use_numpy_random_stream
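
# Illustrative only: `use_numpy_random_stream` is the torch._dynamo config flag
# that the decorator below keys on. A caller could toggle it like this (a
# sketch, not executed here):
#
#   import torch._dynamo.config as dynamo_config
#   dynamo_config.use_numpy_random_stream = True   # delegate to numpy.random
#   dynamo_config.use_numpy_random_stream = False  # use torch's generator
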
def deco_stream(func):
    @functools.wraps(func)
    def inner(*args, **kwds):
        if not use_numpy_random():
            return func(*args, **kwds)
        else:
            import numpy

            from ._ndarray import ndarray

            f = getattr(numpy.random, func.__name__)

            # unwrap ndarray arguments into plain numpy arrays for the numpy.random call
            args = tuple(
                arg.tensor.numpy() if isinstance(arg, ndarray) else arg for arg in args
            )
            kwds = {
                key: val.tensor.numpy() if isinstance(val, ndarray) else val
                for key, val in kwds.items()
            }

            value = f(*args, **kwds)

            # wrap a numpy.ndarray result back into an ndarray; other return
            # values (scalars, None) pass through unchanged
            if isinstance(value, numpy.ndarray):
                value = ndarray(torch.as_tensor(value))

            return value

    return inner
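
# A sketch of the dispatch (comments only, not executed): with
# `use_numpy_random()` returning False, a call such as
#
#   >>> randint(10, size=(3,))
#
# runs the torch-based implementation defined below; with the flag set, the
# same call is forwarded to `numpy.random.randint(10, size=(3,))`, ndarray
# arguments are converted to numpy arrays first, and any numpy.ndarray result
# is wrapped back into an ndarray.
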
@deco_stream
def seed(seed=None):
    if seed is not None:
        torch.random.manual_seed(seed)
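
# Illustrative usage (assumes the torch-backed stream; note that, unlike
# np.random.seed, calling seed(None) leaves the generator state untouched):
#
#   >>> seed(1234)
#   >>> a = random_sample((2, 2))
#   >>> seed(1234)
#   >>> b = random_sample((2, 2))   # same values as `a`
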
@deco_stream
def random_sample(size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).uniform_()
    return array_or_scalar(values, return_scalar=size == ())


def rand(*size):
    if size == ():
        size = None
    return random_sample(size)


sample = random_sample
random = random_sample
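
# Illustrative usage of the unit-interval samplers, assuming the torch-backed
# path (shapes are exact, values are of course random):
#
#   >>> random_sample((2, 3)).shape
#   (2, 3)
#   >>> rand(2, 3).shape          # rand takes the shape as positional ints
#   (2, 3)
#   >>> x = random_sample()       # size=None gives a single scalar draw
#   >>> sample is random_sample and random is random_sample
#   True
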
@deco_stream
def uniform(low=0.0, high=1.0, size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).uniform_(low, high)
    return array_or_scalar(values, return_scalar=size == ())
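
# Illustrative usage (a hedged sketch; values are random, shapes are exact):
#
#   >>> uniform(-1.0, 1.0, size=(2, 5)).shape
#   (2, 5)
#   >>> x = uniform(10.0, 20.0)   # size=None gives a single draw between low and high
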
@deco_stream
def randn(*size):
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.randn(size, dtype=dtype)
    return array_or_scalar(values, return_scalar=size == ())
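
# Illustrative usage (standard-normal draws; the shape is given positionally):
#
#   >>> randn(4, 4).shape
#   (4, 4)
#   >>> x = randn()               # no arguments gives a single scalar draw
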
@deco_stream
def normal(loc=0.0, scale=1.0, size=None):
    if size is None:
        size = ()
    dtype = _dtypes_impl.default_dtypes().float_dtype
    values = torch.empty(size, dtype=dtype).normal_(loc, scale)
    return array_or_scalar(values, return_scalar=size == ())
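
# Illustrative usage (mean `loc`, standard deviation `scale`):
#
#   >>> normal(loc=10.0, scale=0.5, size=(3,)).shape
#   (3,)
#   >>> x = normal()              # size=None gives a single standard-normal draw
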
@deco_stream
def shuffle(x):
    # work on the underlying tensor so that the shuffle happens in place
    from ._ndarray import ndarray

    if isinstance(x, torch.Tensor):
        tensor = x
    elif isinstance(x, ndarray):
        tensor = x.tensor
    else:
        raise NotImplementedError("We do not random.shuffle lists in-place")

    # permute along the first axis and write the result back into x's storage
    perm = torch.randperm(tensor.shape[0])
    xp = tensor[perm]
    tensor.copy_(xp)
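
# Illustrative usage (in-place, along the first axis only; plain Python lists
# are rejected, unlike np.random.shuffle):
#
#   >>> x = randint(0, 100, size=(5, 2))
#   >>> shuffle(x)                # returns None; reorders the rows of x in place
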
@deco_stream
def randint(low, high=None, size=None):
    if size is None:
        size = ()
    if not isinstance(size, (tuple, list)):
        size = (size,)
    if high is None:
        # numpy-style single-argument form: randint(n) draws from [0, n)
        low, high = 0, low
    values = torch.randint(low, high, size=size)
    return array_or_scalar(values, int, return_scalar=size == ())
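
# Illustrative usage (the upper bound is exclusive, as in np.random.randint):
#
#   >>> x = randint(5)            # a single integer drawn from [0, 5)
#   >>> randint(10, 20, size=(2, 3)).shape
#   (2, 3)
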
@deco_stream
@normalizer
def choice(a: ArrayLike, size=None, replace=True, p: Optional[ArrayLike] = None):
    # a single-element `a` means "choose from arange(a)", as in np.random.choice(n)
    if a.numel() == 1:
        a = torch.arange(a)

    # number of draws
    if size is None:
        num_el = 1
    elif _util.is_sequence(size):
        num_el = 1
        for el in size:
            num_el *= el
    else:
        num_el = size

    # probabilities: default to a uniform distribution over the elements of `a`
    if p is None:
        p = torch.ones_like(a) / a.shape[0]

    # check that the probabilities (approximately) sum to one
    atol = sqrt(torch.finfo(p.dtype).eps)
    if abs(p.sum() - 1.0) > atol:
        raise ValueError("probabilities do not sum to 1.")

    # draw flat indices, reshape them to `size` if it is a sequence, and gather from `a`
    indices = torch.multinomial(p, num_el, replacement=replace)

    if _util.is_sequence(size):
        indices = indices.reshape(size)

    samples = a[indices]

    return samples
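
# Illustrative usage (note: unlike np.random.choice, size=None currently yields
# a one-element array rather than a scalar):
#
#   >>> choice(5, size=3)                        # three draws from arange(5)
#   >>> choice([1, 2, 3], size=(2, 2), p=[0.2, 0.3, 0.5]).shape
#   (2, 2)
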