# Copyright The Lightning team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Any, List, Optional

import torch
from torch import Tensor
from torch.nn import functional as F  # noqa: N812
from typing_extensions import Literal


def reduce(x: Tensor, reduction: Literal["elementwise_mean", "sum", "none", None]) -> Tensor:
"""Reduces a given tensor by a given reduction method.
Args:
x: the tensor, which shall be reduced
reduction: a string specifying the reduction method ('elementwise_mean', 'none', 'sum')
Return:
reduced Tensor
Raise:
ValueError if an invalid reduction parameter was given
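
    Example:
        A minimal illustrative sketch of the three reduction modes (values chosen arbitrarily):

        >>> x = torch.tensor([1.0, 2.0, 3.0])
        >>> reduce(x, "elementwise_mean")
        tensor(2.)
        >>> reduce(x, "sum")
        tensor(6.)
        >>> reduce(x, "none")
        tensor([1., 2., 3.])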
"""
if reduction == "elementwise_mean":
return torch.mean(x)
if reduction == "none" or reduction is None:
return x
if reduction == "sum":
return torch.sum(x)
raise ValueError("Reduction parameter unknown.")
def class_reduce(
num: Tensor,
denom: Tensor,
weights: Tensor,
class_reduction: Literal["micro", "macro", "weighted", "none", None] = "none",
) -> Tensor:
"""Reduce classification metrics of the form ``num / denom * weights``.
For example for calculating standard accuracy the num would be number of true positives per class, denom would be
the support per class, and weights would be a tensor of 1s.
Args:
num: numerator tensor
denom: denominator tensor
weights: weights for each class
class_reduction: reduction method for multiclass problems:
- ``'micro'``: calculate metrics globally (default)
- ``'macro'``: calculate metrics for each label, and find their unweighted mean.
- ``'weighted'``: calculate metrics for each label, and find their weighted mean.
- ``'none'`` or ``None``: returns calculated metric per class
Raises:
ValueError:
If ``class_reduction`` is none of ``"micro"``, ``"macro"``, ``"weighted"``, ``"none"`` or ``None``.
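
    Example:
        A small illustrative sketch with arbitrary accuracy-style values for two classes:

        >>> num = torch.tensor([3.0, 1.0])     # true positives per class
        >>> denom = torch.tensor([4.0, 2.0])   # support per class
        >>> weights = torch.tensor([4.0, 2.0])
        >>> class_reduce(num, denom, weights, class_reduction="none")
        tensor([0.7500, 0.5000])
        >>> class_reduce(num, denom, weights, class_reduction="micro")
        tensor(0.6667)
        >>> class_reduce(num, denom, weights, class_reduction="macro")
        tensor(0.6250)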
"""
valid_reduction = ("micro", "macro", "weighted", "none", None)
fraction = torch.sum(num) / torch.sum(denom) if class_reduction == "micro" else num / denom
# We need to take care of instances where the denom can be 0
# for some (or all) classes which will produce nans
fraction[fraction != fraction] = 0
if class_reduction == "micro":
return fraction
if class_reduction == "macro":
return torch.mean(fraction)
if class_reduction == "weighted":
return torch.sum(fraction * (weights.float() / torch.sum(weights)))
if class_reduction == "none" or class_reduction is None:
return fraction
raise ValueError(f"Reduction parameter {class_reduction} unknown. Choose between one of these: {valid_reduction}")
def _simple_gather_all_tensors(result: Tensor, group: Any, world_size: int) -> List[Tensor]:
with torch.no_grad():
gathered_result = [torch.zeros_like(result) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result, group)
# to propagate autograd graph from local rank
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result
def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tensor]:
"""Gather all tensors from several ddp processes onto a list that is broadcast to all processes.
Works on tensors that have the same number of dimensions, but where each dimension may differ. In this case
tensors are padded, gathered and then trimmed to secure equal workload for all processes.
Args:
result: the value to sync
group: the process group to gather results from. Defaults to all processes (world)
Return:
list with size equal to the process group where element i corresponds to result tensor from process i
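
    Example:
        A minimal sketch, assuming a ``torch.distributed`` process group has already been initialized (e.g. via
        ``torch.distributed.init_process_group``) and that this function is called on every process::

            tensor = torch.arange(3) + 3 * torch.distributed.get_rank()
            gathered = gather_all_tensors(tensor)
            # gathered[i] now holds the tensor contributed by process i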
"""
if group is None:
group = torch.distributed.group.WORLD
# convert tensors to contiguous format
result = result.contiguous()
world_size = torch.distributed.get_world_size(group)
torch.distributed.barrier(group=group)
# if the tensor is scalar, things are easy
if result.ndim == 0:
return _simple_gather_all_tensors(result, group, world_size)
# 1. Gather sizes of all tensors
local_size = torch.tensor(result.shape, device=result.device)
local_sizes = [torch.zeros_like(local_size) for _ in range(world_size)]
torch.distributed.all_gather(local_sizes, local_size, group=group)
max_size = torch.stack(local_sizes).max(dim=0).values
all_sizes_equal = all(all(ls == max_size) for ls in local_sizes)
# 2. If shapes are all the same, then do a simple gather:
if all_sizes_equal:
return _simple_gather_all_tensors(result, group, world_size)
# 3. If not, we need to pad each local tensor to maximum size, gather and then truncate
with torch.no_grad():
pad_dims = []
pad_by = (max_size - local_size).detach().cpu()
for val in reversed(pad_by):
pad_dims.append(0)
pad_dims.append(val.item())
result_padded = F.pad(result, pad_dims)
gathered_result = [torch.zeros_like(result_padded) for _ in range(world_size)]
torch.distributed.all_gather(gathered_result, result_padded, group)
for idx, item_size in enumerate(local_sizes):
slice_param = [slice(dim_size) for dim_size in item_size]
gathered_result[idx] = gathered_result[idx][slice_param]
# to propagate autograd graph from local rank
gathered_result[torch.distributed.get_rank(group)] = result
return gathered_result