import contextlib
import warnings
from collections.abc import Generator

import torch
from torch._C import default_generator


def set_rng_state(new_state: torch.Tensor) -> None:
    r"""Sets the random number generator state.

    .. note:: This function only works for CPU. For CUDA, please use
        :func:`torch.manual_seed`, which works for both CPU and CUDA.

    Args:
        new_state (torch.ByteTensor): The desired state
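
    A minimal round-trip check (``torch.rand`` is just an arbitrary consumer
    of randomness here)::

        >>> saved = torch.get_rng_state()
        >>> x = torch.rand(2)          # advances the default CPU generator
        >>> torch.set_rng_state(saved)
        >>> torch.equal(torch.rand(2), x)
        True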
    """
    default_generator.set_state(new_state)


def get_rng_state() -> torch.Tensor:
    r"""Returns the random number generator state as a `torch.ByteTensor`.

    .. note:: The returned state is for the default generator on CPU only.

    See also: :func:`torch.random.fork_rng`.
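
    A small checkpointing sketch (variable names are illustrative)::

        >>> rng_snapshot = torch.get_rng_state()  # a torch.uint8 CPU tensor
        >>> # ... run code that consumes randomness ...
        >>> torch.set_rng_state(rng_snapshot)     # resume the same stream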
    """
    return default_generator.get_state()


def manual_seed(seed) -> torch._C.Generator:
    r"""Sets the seed for generating random numbers on all devices. Returns a
    `torch.Generator` object.

    Args:
        seed (int): The desired seed. Value must be within the inclusive range
            `[-0x8000_0000_0000_0000, 0xffff_ffff_ffff_ffff]`. Otherwise, a RuntimeError
            is raised. Negative inputs are remapped to positive values with the formula
            `0xffff_ffff_ffff_ffff + seed`.
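
    A minimal reproducibility check (illustrative)::

        >>> _ = torch.manual_seed(0)  # seeds CPU and all visible accelerators
        >>> a = torch.rand(3)
        >>> _ = torch.manual_seed(0)
        >>> torch.equal(torch.rand(3), a)
        True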
    """
    seed = int(seed)
    import torch.cuda

    # `_is_in_bad_fork()` is True in a subprocess forked after the device
    # runtime was initialized; seeding such a device would fail, so skip it.
    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return default_generator.manual_seed(seed)


def seed() -> int:
    r"""Sets the seed for generating random numbers to a non-deterministic
    random number on all devices. Returns a 64-bit number used to seed the RNG.
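
    For example (the returned value itself is non-deterministic)::

        >>> s = torch.seed()          # fresh 64-bit seed, applied everywhere
        >>> _ = torch.manual_seed(s)  # replay the same random stream later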
    """
    seed = default_generator.seed()
    import torch.cuda

    if not torch.cuda._is_in_bad_fork():
        torch.cuda.manual_seed_all(seed)

    import torch.mps

    if not torch.mps._is_in_bad_fork():
        torch.mps.manual_seed(seed)

    import torch.xpu

    if not torch.xpu._is_in_bad_fork():
        torch.xpu.manual_seed_all(seed)

    _seed_custom_device(seed)

    return seed


def _seed_custom_device(seed) -> None:
    r"""Sets the seed to generate random numbers for the custom device.

    Args:
        seed (int): The desired seed.

    See [Note: support the custom device with privateuse1]
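
    A rough sketch of the backend module surface this helper probes for,
    where ``foo_backend`` is a hypothetical extension registered via
    ``torch._register_device_module``::

        import torch
        import foo_backend  # hypothetical; must define the two hooks below

        # foo_backend._is_in_bad_fork() -> bool
        # foo_backend.manual_seed_all(seed: int) -> None
        torch._register_device_module("foo", foo_backend)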
    """
    seed = int(seed)
    custom_backend_name = torch._C._get_privateuse1_backend_name()
    if hasattr(torch, custom_backend_name):
        custom_device_mod = getattr(torch, custom_backend_name)
        _bad_fork_name = "_is_in_bad_fork"
        _seed_all_name = "manual_seed_all"
        if hasattr(custom_device_mod, _bad_fork_name) and hasattr(
            custom_device_mod, _seed_all_name
        ):
            if not getattr(custom_device_mod, _bad_fork_name)():
                getattr(custom_device_mod, _seed_all_name)(seed)
        else:
            message = f"Setting the seed for the `{custom_backend_name}` device has no effect; please add the APIs "
            message += f"`{_bad_fork_name}` and `{_seed_all_name}` to the `{custom_backend_name}` device module."
            warnings.warn(message, UserWarning, stacklevel=3)


def initial_seed() -> int:
    r"""Returns the initial seed for generating random numbers as a
    Python `int`.

    .. note:: The returned seed is for the default generator on CPU only.
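
    For instance::

        >>> _ = torch.manual_seed(7)
        >>> torch.initial_seed()
        7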
    """
    return default_generator.initial_seed()


_fork_rng_warned_already = False


@contextlib.contextmanager
def fork_rng(
    devices=None,
    enabled=True,
    _caller="fork_rng",
    _devices_kw="devices",
    device_type="cuda",
) -> Generator:
    """
    Forks the RNG, so that when you return, the RNG is reset
    to the state it was previously in.

    Args:
        devices (iterable of Device IDs): devices for which to fork
            the RNG. CPU RNG state is always forked. By default, :meth:`fork_rng` operates
            on all devices, but will emit a warning if your machine has a lot
            of devices, since this function will run very slowly in that case.
            If you explicitly specify devices, this warning will be suppressed.
        enabled (bool): if ``False``, the RNG is not forked. This is a convenience
            argument for easily disabling the context manager without having
            to delete it and unindent the Python code under it.
        device_type (str): device type string, default is `cuda`. For custom devices,
            see details in [Note: support the custom device with privateuse1]
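
    A small CPU-only sketch (``devices=[]`` forks just the CPU state)::

        >>> with torch.random.fork_rng(devices=[]):
        ...     _ = torch.manual_seed(0)  # affects the CPU generator in here
        ...     _ = torch.rand(2)
        >>> # on exit, the CPU RNG state from before the block is restored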
    """

    if device_type == "meta":
        yield
        return

    device_type = torch.device(device_type).type
    device_mod = getattr(torch, device_type, None)
    if device_mod is None:
        raise RuntimeError(
            f"torch has no module of `{device_type}`, you should register "
            + "a module by `torch._register_device_module`."
        )
    global _fork_rng_warned_already

    # Internal arguments:
    #   _caller: the function which called fork_rng, which the user used
    #   _devices_kw: the devices keyword of _caller

    if not enabled:
        yield
        return

    if devices is None:
        num_devices = device_mod.device_count()
        if num_devices > 1 and not _fork_rng_warned_already:
            message = (
                f"{device_type.upper()} reports that you have {num_devices} available devices, and "
                f"you have used {_caller} without explicitly specifying which devices are being used. "
                f"For safety, we initialize *every* {device_type.upper()} device by default, which can "
                f"be quite slow if you have a lot of {device_type.upper()}s. If you know that you are only"
                f" making use of a few {device_type.upper()} devices, set the environment variable "
                f"{device_type.upper()}_VISIBLE_DEVICES or the '{_devices_kw}' keyword argument of {_caller} "
                "with the set of devices you are actually using. For example, if you are using CPU only, "
                f"set {device_type.upper()}_VISIBLE_DEVICES= or devices=[]; if you are using device 0 only, "
                f"set {device_type.upper()}_VISIBLE_DEVICES=0 or devices=[0]. To initialize all devices "
                f"and suppress this warning, set the '{_devices_kw}' keyword argument to "
                f"`range(torch.{device_type}.device_count())`."
            )
            warnings.warn(message)
            _fork_rng_warned_already = True
        devices = list(range(num_devices))
    else:
        # Protect against the user passing us a generator; we need to traverse
        # this list twice, but a generator can't be traversed twice.
        devices = list(devices)

    cpu_rng_state = torch.get_rng_state()
    device_rng_states = [device_mod.get_rng_state(device) for device in devices]

    try:
        yield
    finally:
        torch.set_rng_state(cpu_rng_state)
        for device, device_rng_state in zip(devices, device_rng_states):
            device_mod.set_rng_state(device_rng_state, device)