import contextlib
from typing import Dict, Iterator, Set, Union

import torch
from torch.cuda import _lazy_call
from torch.utils.checkpoint import detach_variable

from .initialize import get_data_parallel_rank, get_model_parallel_rank

_MODEL_PARALLEL_RNG_TRACKER_NAME = "model-parallel-rng"
|
|
def _set_cuda_rng_state(new_state: torch.ByteTensor, device: Union[int, str, torch.device] = -1) -> None:
    """Sets the random number generator state of the current GPU.

    Arguments:
        new_state (torch.ByteTensor): The desired state.
        device (int, str, or torch.device): The device whose state is set;
            defaults to the current CUDA device.

    This function is adapted from the PyTorch repo (torch.cuda.set_rng_state)
    with a single change: the input state is not cloned. Cloning caused
    major performance issues for setups with 4 or more GPUs.
    """
    if device == -1:
        device = torch.device("cuda")
    elif isinstance(device, str):
        device = torch.device(device)
    elif isinstance(device, int):
        device = torch.device("cuda", device)

    def cb() -> None:
        # Resolve the device index, falling back to the current device.
        idx = device.index
        if idx is None:
            idx = torch.cuda.current_device()
        default_generator = torch.cuda.default_generators[idx]
        default_generator.set_state(new_state)

    _lazy_call(cb)
|
|
class CudaRNGStatesTracker:
    """Tracker for the cuda RNG states.

    Using the `add` method, a cuda rng state is initialized based on
    the input `seed` and is assigned to `name`. Later, by forking the
    rng state, we can perform operations and return to our starting
    cuda state.
    """

    def __init__(self) -> None:
        # Map from state name to the corresponding cuda rng state.
        self.states_: Dict[str, torch.ByteTensor] = {}
        # Seeds are kept only to make sure no seed is reused.
        self.seeds_: Set[int] = set()

    def reset(self) -> None:
        """Set to the initial state (no tracker)."""
        self.states_ = {}
        self.seeds_ = set()

    def get_states(self) -> Dict[str, torch.ByteTensor]:
        """Get rng states. Copy the dictionary so we have direct
        pointers to the states, not just a pointer to the dictionary."""
        states = {}
        for name in self.states_:
            states[name] = self.states_[name]
        return states

    def set_states(self, states: Dict[str, torch.ByteTensor]) -> None:
        """Set the rng states. For efficiency, the states are not checked
        for compatibility."""
        self.states_ = states

    def add(self, name: str, seed: int) -> None:
        """Track the rng state.

        Arguments:
            name (str): The name of the state.
            seed (int): The seed value.
        """
        # Check that the seed and name have not been used before.
        if seed in self.seeds_:
            raise Exception("seed {} already exists".format(seed))
        self.seeds_.add(seed)
        if name in self.states_:
            raise Exception("cuda rng state {} already exists".format(name))
        # Remember the current rng state.
        orig_rng_state = torch.cuda.get_rng_state()
        # Seed the generator, store the resulting state under `name`, and
        # restore the original rng state.
        torch.cuda.manual_seed(seed)
        self.states_[name] = torch.cuda.get_rng_state()
        _set_cuda_rng_state(orig_rng_state)

    @contextlib.contextmanager
    def fork(self, name: str = _MODEL_PARALLEL_RNG_TRACKER_NAME) -> Iterator[None]:
        """Fork the cuda rng state, perform operations, and exit with
        the original state."""
        # Check that the requested state exists.
        if name not in self.states_:
            raise Exception("cuda rng state {} is not added".format(name))
        # Store the current rng state and switch to the tracked state.
        orig_cuda_rng_state = torch.cuda.get_rng_state()
        _set_cuda_rng_state(self.states_[name])
        try:
            yield
        finally:
            # Update the tracked state and restore the original rng state.
            self.states_[name] = torch.cuda.get_rng_state()
            _set_cuda_rng_state(orig_cuda_rng_state)
|
|
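# Illustrative sketch (not part of the original module): how the tracker is
# meant to be used. The state name and seed below are made up. `add` registers
# a named cuda rng state seeded independently of the default generator, and
# `fork` temporarily switches to that state so random ops inside the block do
# not advance the default generator.
def _example_tracker_usage() -> None:
    tracker = CudaRNGStatesTracker()
    tracker.add("example-state", 1234)
    dropout = torch.nn.Dropout(p=0.1)
    x = torch.randn(4, 4, device="cuda")
    with tracker.fork("example-state"):
        # Draws randomness from (and advances) "example-state" only.
        y = dropout(x)
    # Back on the default cuda generator from here on.
    _ = y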
|
|
# RNG tracker object shared by this process.
_CUDA_RNG_STATE_TRACKER = CudaRNGStatesTracker()


def get_cuda_rng_tracker() -> CudaRNGStatesTracker:
    """Get cuda rng tracker."""
    return _CUDA_RNG_STATE_TRACKER
|
|
def model_parallel_cuda_manual_seed(seed: int) -> None:
    """Initialize model parallel cuda seed.

    This function should be called after model parallelism has been
    initialized. No torch.cuda.manual_seed should be called after this
    function; it is essentially a replacement for that call.

    Two sets of RNG states are tracked:
        default state: This is for data parallelism and is the same among a
            set of model parallel GPUs but different across different model
            parallel groups. This is used, for example, for dropout in
            non-model-parallel regions.
        model-parallel state: This state is different among a set of model
            parallel GPUs, but the same across data parallel groups. This is
            used, for example, for dropout in model parallel regions.
    """
    # Offset the seed so the model parallel seed differs from the data
    # parallel seed and varies per model parallel rank.
    offset = seed + 2718
    model_parallel_seed = offset + get_model_parallel_rank()
    # Data parallel uses the original seed, which is identical within a
    # model parallel group.
    data_parallel_seed = seed

    if torch.distributed.get_rank() == 0:
        print(
            "> initializing model parallel cuda seeds on global rank {}, "
            "model parallel rank {}, and data parallel rank {} with "
            "model parallel seed: {} and data parallel seed: {}".format(
                torch.distributed.get_rank(),
                get_model_parallel_rank(),
                get_data_parallel_rank(),
                model_parallel_seed,
                data_parallel_seed,
            ),
            flush=True,
        )
    if torch.cuda.is_available():
        _CUDA_RNG_STATE_TRACKER.reset()
        # Set the default state with the data parallel seed.
        torch.cuda.manual_seed(data_parallel_seed)
        # Add the model parallel state to the tracker.
        _CUDA_RNG_STATE_TRACKER.add(_MODEL_PARALLEL_RNG_TRACKER_NAME, model_parallel_seed)
|
|
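# Illustrative sketch (not part of the original module): typical call order in
# a training script. It assumes the process was launched by a distributed
# launcher (e.g. torchrun) and that `.initialize` provides an
# `initialize_model_parallel` helper; the group size and seed are made up.
def _example_seed_setup() -> None:
    from .initialize import initialize_model_parallel  # assumed helper

    torch.distributed.init_process_group(backend="nccl")
    initialize_model_parallel(2)
    model_parallel_cuda_manual_seed(1234)

    hidden = torch.randn(8, 16, device="cuda")
    # Dropout inside a model parallel region: fork to the model-parallel
    # state so each model parallel rank draws a different mask.
    with get_cuda_rng_tracker().fork():
        hidden = torch.nn.functional.dropout(hidden, p=0.1)
    # Dropout outside model parallel regions uses the default (data parallel)
    # state, which is identical across a model parallel group.
    hidden = torch.nn.functional.dropout(hidden, p=0.1)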
|
|
class CheckpointFunction(torch.autograd.Function):
    """This function is adapted from torch.utils.checkpoint with
    two main changes:
    1) torch.cuda.set_rng_state is replaced with `_set_cuda_rng_state`
    2) the states in the model parallel tracker are also properly
       tracked/set/reset.
    """

    @staticmethod
    def forward(ctx, run_function, *args):
        ctx.run_function = run_function

        # Copy the rng states so the forward pass can be replayed exactly
        # during the backward pass.
        ctx.fwd_cpu_rng_state = torch.get_rng_state()
        ctx.fwd_cuda_rng_state = torch.cuda.get_rng_state()
        ctx.fwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        ctx.save_for_backward(*args)
        with torch.no_grad():
            outputs = run_function(*args)
        return outputs

    @staticmethod
    def backward(ctx, *args):
        if not torch.autograd._is_checkpoint_valid():
            raise RuntimeError("Checkpointing is not compatible with .grad(), please use .backward() if possible")
        inputs = ctx.saved_tensors

        # Store the current states so they can be restored after the
        # recomputation.
        bwd_cpu_rng_state = torch.get_rng_state()
        bwd_cuda_rng_state = torch.cuda.get_rng_state()
        bwd_cuda_rng_state_tracker = get_cuda_rng_tracker().get_states()

        # Set the states to what they were at the start of the forward pass.
        torch.set_rng_state(ctx.fwd_cpu_rng_state)
        _set_cuda_rng_state(ctx.fwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(ctx.fwd_cuda_rng_state_tracker)

        # Recompute the forward pass with gradients enabled.
        detached_inputs = detach_variable(inputs)
        with torch.enable_grad():
            outputs = ctx.run_function(*detached_inputs)

        # Set the states back to what they were at the start of this function.
        torch.set_rng_state(bwd_cpu_rng_state)
        _set_cuda_rng_state(bwd_cuda_rng_state)
        get_cuda_rng_tracker().set_states(bwd_cuda_rng_state_tracker)

        if isinstance(outputs, torch.Tensor):
            outputs = (outputs,)
        torch.autograd.backward(outputs, args)
        return (None,) + tuple(inp.grad for inp in detached_inputs)
|
|
def checkpoint(function, *args):
    """Checkpoint a model or part of the model.
    This has been directly copied from torch.utils.checkpoint."""
    return CheckpointFunction.apply(function, *args)
|
|
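# Illustrative sketch (not part of the original module): checkpointing a small
# block so its activations are recomputed during backward. The layer shapes
# are made up. Because CheckpointFunction restores both the default cuda rng
# state and the tracker states before recomputing, dropout inside the block
# produces the same mask in the replay as in the original forward pass.
def _example_checkpoint_usage() -> None:
    layer = torch.nn.Sequential(
        torch.nn.Linear(16, 16),
        torch.nn.Dropout(p=0.1),
    ).cuda()

    def run(x: torch.Tensor) -> torch.Tensor:
        return layer(x)

    inputs = torch.randn(4, 16, device="cuda", requires_grad=True)
    out = checkpoint(run, inputs)
    out.sum().backward()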