import collections.abc as abc
from dataclasses import dataclass
from math import inf
from typing import Any, Callable, Dict, List, Optional

import torch
import torch.distributed as dist


@dataclass
class Workhandle:
    """Pairs a work handle (e.g. from an async collective) with an optional completion callback."""

    handle: Any
    callback: Optional[Callable] = None
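
# Illustrative sketch (not part of the original module): one way a Workhandle
# could pair an async all-reduce with a completion callback. The names
# `grad_buffer` and `wh` are hypothetical, and this assumes an initialized
# process group.
#
#     work = dist.all_reduce(grad_buffer, async_op=True)
#     wh = Workhandle(handle=work, callback=lambda: grad_buffer.div_(dist.get_world_size()))
#     ...
#     wh.handle.wait()
#     if wh.callback is not None:
#         wh.callback()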


def get_global_rank(group: Any, rank: int) -> int:
    """Return the global rank corresponding to ``rank`` within ``group``."""
    if group is dist.group.WORLD:
        return rank

    return dist.distributed_c10d._get_global_rank(group, rank)
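
# Illustrative sketch (not part of the original module): mapping a group-local
# rank to its global rank. Assumes torch.distributed is initialized with at
# least three processes; `subgroup` is a hypothetical name.
#
#     subgroup = dist.new_group(ranks=[0, 2])
#     # Local rank 1 inside `subgroup` is global rank 2.
#     assert get_global_rank(subgroup, 1) == 2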


def recursive_copy_to_device(value: Any, non_blocking: bool, device: torch.device) -> Any:
    """
    Recursively searches lists, tuples, dicts and copies tensors to device if
    possible. Non-tensor values are passed as-is in the result.

    NOTE: These are all copies, so two references to the same tensor in the
    input become two distinct tensors on the device after this call.
    """

    if isinstance(value, torch.Tensor):
        return value.to(device, non_blocking=non_blocking)

    if isinstance(value, (list, tuple)):
        values = []
        for val in value:
            values.append(recursive_copy_to_device(val, non_blocking=non_blocking, device=device))

        return values if isinstance(value, list) else tuple(values)

    if isinstance(value, abc.Mapping):
        device_val: Dict[str, Any] = {}
        for key, val in value.items():
            device_val[key] = recursive_copy_to_device(val, non_blocking=non_blocking, device=device)

        return device_val

    return value
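
# Illustrative sketch (not part of the original module): moving a nested,
# optimizer-style state dict to the GPU in one call. The structure below is an
# arbitrary example, and it assumes a CUDA device is available.
#
#     state = {"step": 3, "exp_avg": torch.zeros(4), "extra": [torch.ones(2), "tag"]}
#     gpu_state = recursive_copy_to_device(state, non_blocking=True, device=torch.device("cuda"))
#     # Tensors are copied to the GPU; the int and str leaves are returned unchanged.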


def calc_grad_norm(parameters: List[torch.nn.Parameter], p: float) -> torch.Tensor:
    r"""Calculate the gradient norm of an iterable of parameters.

    Only parameters whose gradient is not ``None`` contribute to the norm.

    Returns:
        Total norm of the gradients (viewed as a single vector).
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(filter(lambda par: par.grad is not None, parameters))

    if len(parameters) == 0:
        return torch.tensor(0.0)
    p = float(p)
    if p == inf:
        local_norm = max(par.grad.detach().abs().max() for par in parameters)
    else:
        # Accumulate per-parameter norms in fp32 to avoid overflow in reduced
        # precision, then cast the result back to the parameters' dtype.
        local_norm = torch.norm(
            torch.stack([torch.norm(par.grad.detach(), p, dtype=torch.float32) for par in parameters]), p
        ).to(dtype=parameters[0].dtype)
    return local_norm
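
# Illustrative sketch (not part of the original module): computing L2 and
# L-infinity gradient norms for a small model after a backward pass.
#
#     model = torch.nn.Linear(8, 4)
#     model(torch.randn(2, 8)).sum().backward()
#     l2_norm = calc_grad_norm(list(model.parameters()), p=2)
#     max_norm = calc_grad_norm(list(model.parameters()), p=inf)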