# Copyright 2022 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
from typing import Optional, Union

import numpy as np
import torch

from ..state import AcceleratorState
from .constants import CUDA_DISTRIBUTED_TYPES
from .dataclasses import DistributedType, RNGType
from .imports import (
    is_hpu_available,
    is_mlu_available,
    is_musa_available,
    is_npu_available,
    is_sdaa_available,
    is_torch_xla_available,
    is_xpu_available,
)


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm


def set_seed(seed: int, device_specific: bool = False, deterministic: bool = False):
    """
    Helper function for reproducible behavior to set the seed in `random`, `numpy` and `torch`.

    Args:
        seed (`int`):
            The seed to set.
        device_specific (`bool`, *optional*, defaults to `False`):
            Whether to offset the seed on each device by its process index (`AcceleratorState().process_index`), so
            that each process is seeded differently but reproducibly.
        deterministic (`bool`, *optional*, defaults to `False`):
            Whether to use deterministic algorithms where available. Can slow down training.
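
    Example (a minimal usage sketch):

    ```python
    >>> from accelerate.utils import set_seed

    >>> set_seed(42)
    >>> # Per-process seeds plus deterministic kernels
    >>> set_seed(42, device_specific=True, deterministic=True)
    ```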
"""
    if device_specific:
        seed += AcceleratorState().process_index
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if is_xpu_available():
        torch.xpu.manual_seed_all(seed)
    elif is_npu_available():
        torch.npu.manual_seed_all(seed)
    elif is_mlu_available():
        torch.mlu.manual_seed_all(seed)
    elif is_sdaa_available():
        torch.sdaa.manual_seed_all(seed)
    elif is_musa_available():
        torch.musa.manual_seed_all(seed)
    elif is_hpu_available():
        torch.hpu.manual_seed_all(seed)
    else:
        torch.cuda.manual_seed_all(seed)
        # ^^ safe to call this function even if cuda is not available
    if is_torch_xla_available():
        xm.set_rng_state(seed)
    if deterministic:
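        # Note (assumption about the CUDA runtime): some kernels additionally require the
        # CUBLAS_WORKSPACE_CONFIG environment variable (e.g. ":4096:8") to be set before
        # launch to behave fully deterministically.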
        torch.use_deterministic_algorithms(True)


def synchronize_rng_state(rng_type: Optional[RNGType] = None, generator: Optional[torch.Generator] = None):
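    """
    Synchronizes the RNG state of one type across processes by broadcasting the state on process 0 to all others.
    Outside of a distributed run, the state is read and written back unchanged.

    Args:
        rng_type (`RNGType`, *optional*):
            The kind of RNG state to synchronize (e.g. `RNGType.TORCH`, `RNGType.CUDA` or `RNGType.GENERATOR`).
        generator (`torch.Generator`, *optional*):
            The generator whose state to synchronize when `rng_type` is `RNGType.GENERATOR`.
    """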
    # Get the proper rng state
    if rng_type == RNGType.TORCH:
        rng_state = torch.get_rng_state()
    elif rng_type == RNGType.CUDA:
        rng_state = torch.cuda.get_rng_state()
    elif rng_type == RNGType.XLA:
        assert is_torch_xla_available(), "Can't synchronize XLA seeds as torch_xla is unavailable."
        rng_state = torch.tensor(xm.get_rng_state())
    elif rng_type == RNGType.NPU:
        assert is_npu_available(), "Can't synchronize NPU seeds on an environment without NPUs."
        rng_state = torch.npu.get_rng_state()
    elif rng_type == RNGType.MLU:
        assert is_mlu_available(), "Can't synchronize MLU seeds on an environment without MLUs."
        rng_state = torch.mlu.get_rng_state()
    elif rng_type == RNGType.SDAA:
        assert is_sdaa_available(), "Can't synchronize SDAA seeds on an environment without SDAAs."
        rng_state = torch.sdaa.get_rng_state()
    elif rng_type == RNGType.MUSA:
        assert is_musa_available(), "Can't synchronize MUSA seeds on an environment without MUSAs."
        rng_state = torch.musa.get_rng_state()
    elif rng_type == RNGType.XPU:
        assert is_xpu_available(), "Can't synchronize XPU seeds on an environment without XPUs."
        rng_state = torch.xpu.get_rng_state()
    elif rng_type == RNGType.HPU:
        assert is_hpu_available(), "Can't synchronize HPU seeds on an environment without HPUs."
        rng_state = torch.hpu.get_rng_state()
    elif rng_type == RNGType.GENERATOR:
        assert generator is not None, "Need a generator to synchronize its seed."
        rng_state = generator.get_state()
    # Broadcast the rng state from device 0 to other devices
    state = AcceleratorState()
    if state.distributed_type == DistributedType.XLA:
        rng_state = rng_state.to(xm.xla_device())
        xm.collective_broadcast([rng_state])
        xm.mark_step()
        rng_state = rng_state.cpu()
    elif (
        state.distributed_type in CUDA_DISTRIBUTED_TYPES
        or state.distributed_type == DistributedType.MULTI_MLU
        or state.distributed_type == DistributedType.MULTI_SDAA
        or state.distributed_type == DistributedType.MULTI_MUSA
        or state.distributed_type == DistributedType.MULTI_NPU
        or state.distributed_type == DistributedType.MULTI_XPU
        or state.distributed_type == DistributedType.MULTI_HPU
    ):
        rng_state = rng_state.to(state.device)
        torch.distributed.broadcast(rng_state, 0)
        rng_state = rng_state.cpu()
    elif state.distributed_type == DistributedType.MULTI_CPU:
        torch.distributed.broadcast(rng_state, 0)

    # Set the broadcast rng state
    if rng_type == RNGType.TORCH:
        torch.set_rng_state(rng_state)
    elif rng_type == RNGType.CUDA:
        torch.cuda.set_rng_state(rng_state)
    elif rng_type == RNGType.NPU:
        torch.npu.set_rng_state(rng_state)
    elif rng_type == RNGType.MLU:
        torch.mlu.set_rng_state(rng_state)
    elif rng_type == RNGType.SDAA:
        torch.sdaa.set_rng_state(rng_state)
    elif rng_type == RNGType.MUSA:
        torch.musa.set_rng_state(rng_state)
    elif rng_type == RNGType.XPU:
        torch.xpu.set_rng_state(rng_state)
    elif rng_type == RNGType.HPU:
        torch.hpu.set_rng_state(rng_state)
    elif rng_type == RNGType.XLA:
        xm.set_rng_state(rng_state.item())
    elif rng_type == RNGType.GENERATOR:
        generator.set_state(rng_state)


def synchronize_rng_states(rng_types: list[Union[str, RNGType]], generator: Optional[torch.Generator] = None):
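    """
    Synchronizes the RNG states of all processes for every RNG type in `rng_types`.

    Args:
        rng_types (`list` of `str` or `RNGType`):
            The RNG states to synchronize, e.g. `["torch", "cuda", "generator"]`.
        generator (`torch.Generator`, *optional*):
            The generator to synchronize when `"generator"` is in `rng_types`.

    Example (a minimal sketch; inside a launched distributed run, every process ends up with identical states):

    ```python
    >>> import torch
    >>> from accelerate.utils import synchronize_rng_states

    >>> generator = torch.Generator().manual_seed(0)
    >>> synchronize_rng_states(["torch", "generator"], generator=generator)
    ```
    """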
    for rng_type in rng_types:
        synchronize_rng_state(RNGType(rng_type), generator=generator)