from typing import Callable, Optional |
|
|
|
import torch |
|
import torch.nn.functional as F |
|
import torch.nn.init as init |
|
from torch.nn.parameter import Parameter |
|
|
|
from .initialize import get_model_parallel_rank, get_model_parallel_world_size |
|
from .mappings import ( |
|
copy_to_model_parallel_region, |
|
gather_from_model_parallel_region, |
|
reduce_from_model_parallel_region, |
|
scatter_to_model_parallel_region, |
|
) |
|
from .utils import VocabUtility, divide_and_check_no_remainder |
|
|
|
|
|
def _initialize_affine_weight( |
|
weight: torch.Tensor, |
|
out_features: int, |
|
in_features: int, |
|
per_partition_size: int, |
|
partition_dim: int, |
|
init_method: Callable[[torch.Tensor], torch.Tensor], |
|
stride: int = 1, |
|
return_master_weight: bool = False, |
|
) -> Optional[torch.Tensor]: |
|
"""Initialize affine weight for model parallel. |
|
|
|
    Build the master weight on all processes and scatter the relevant chunk.

    With stride > 1 the master weight is split into world_size * stride chunks
    along partition_dim, and rank r keeps chunks r, r + world_size,
    r + 2 * world_size, ...
    """
|
|
|
|
|
world_size = get_model_parallel_world_size() |
|
if world_size == 1: |
|
init_method(weight) |
|
if return_master_weight: |
|
return weight |
|
return None |
|
|
|
|
|
master_weight = torch.empty(out_features, in_features, dtype=weight.dtype, requires_grad=False) |
|
init_method(master_weight) |
|
|
|
|
|
    # Split the master weight into world_size * stride chunks along the
    # partition dimension and keep every world_size-th chunk from this rank on.
    per_partition_per_stride_size = divide_and_check_no_remainder(per_partition_size, stride)
    weight_list = torch.split(master_weight, per_partition_per_stride_size, dim=partition_dim)
    rank = get_model_parallel_rank()
    my_weight_list = weight_list[rank::world_size]

    # Copy this rank's chunks into the local weight without tracking gradients.
    with torch.no_grad():
        torch.cat(my_weight_list, dim=partition_dim, out=weight)
|
if return_master_weight: |
|
return master_weight |
|
return None |
|
|
|
|
|
class VocabParallelEmbedding(torch.nn.Module): |
|
"""Embedding parallelized in the vocabulary dimension. |
|
|
|
This is mainly adapted from torch.nn.Embedding and all the default |
|
values are kept. |
|
Arguments: |
|
num_embeddings: vocabulary size. |
|
embedding_dim: size of hidden state. |
|
init_method: method to initialize weights. |
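
    Example (illustrative sketch; assumes torch.distributed and the model
    parallel group are already initialized, e.g. with an
    initialize_model_parallel helper):

        embedding = VocabParallelEmbedding(num_embeddings=32000, embedding_dim=512)
        tokens = torch.randint(0, 32000, (4, 128))
        # Each rank embeds only its slice of the vocabulary; the all-reduce in
        # forward() makes the full result visible on every rank.
        hidden = embedding(tokens)  # shape (4, 128, 512)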
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_embeddings: int, |
|
embedding_dim: int, |
|
padding_idx: Optional[int] = None, |
|
max_norm: Optional[float] = None, |
|
norm_type: float = 2.0, |
|
scale_grad_by_freq: bool = False, |
|
sparse: bool = False, |
|
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, |
|
) -> None: |
|
super(VocabParallelEmbedding, self).__init__() |
|
|
|
self.num_embeddings = num_embeddings |
|
self.embedding_dim = embedding_dim |
|
self.padding_idx = padding_idx |
|
self.max_norm = max_norm |
|
self.norm_type = norm_type |
|
self.scale_grad_by_freq = scale_grad_by_freq |
|
self.sparse = sparse |
|
self._weight = None |
|
|
|
self.vocab_start_index, self.vocab_end_index = VocabUtility.vocab_range_from_global_vocab_size( |
|
self.num_embeddings, get_model_parallel_rank(), get_model_parallel_world_size() |
|
) |
|
self.num_embeddings_per_partition = self.vocab_end_index - self.vocab_start_index |
|
|
|
|
|
self.weight = Parameter(torch.Tensor(self.num_embeddings_per_partition, self.embedding_dim)) |
|
|
|
_initialize_affine_weight( |
|
self.weight, self.num_embeddings, self.embedding_dim, self.num_embeddings_per_partition, 0, init_method |
|
) |
|
|
|
    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        # Build a mask for tokens that fall outside this partition's vocab range.
        input_mask = (input_ < self.vocab_start_index) | (input_ >= self.vocab_end_index)
        # Shift indices into the local range; out-of-range tokens point at row 0.
        masked_input = input_.clone() - self.vocab_start_index
        masked_input[input_mask] = 0
        # Look up the local partition of the embedding table.
        output_parallel = F.embedding(
            masked_input,
            self.weight,
            self.padding_idx,
            self.max_norm,
            self.norm_type,
            self.scale_grad_by_freq,
            self.sparse,
        )
        # Zero the rows of tokens owned by other partitions, then sum across ranks.
        output_parallel[input_mask, :] = 0.0
        output = reduce_from_model_parallel_region(output_parallel)
        return output
|
|
|
|
|
class ParallelEmbedding(torch.nn.Module): |
|
"""Embedding parallelized in the embedding dimension. |
|
|
|
This is mainly adapted from torch.nn.Embedding and all the default |
|
values are kept. |
|
Arguments: |
|
num_embeddings: vocabulary size. |
|
embedding_dim: size of hidden state. |
|
init_method: method to initialize weights. |
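
    Example (illustrative sketch; assumes the model parallel group is already
    initialized):

        # Each rank stores embedding_dim // world_size columns of the table;
        # the all-gather in forward() restores the full hidden size.
        embedding = ParallelEmbedding(num_embeddings=32000, embedding_dim=512)
        tokens = torch.randint(0, 32000, (4, 128))
        hidden = embedding(tokens)  # shape (4, 128, 512) on every rank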
|
""" |
|
|
|
def __init__( |
|
self, |
|
num_embeddings: int, |
|
embedding_dim: int, |
|
padding_idx: Optional[int] = None, |
|
max_norm: Optional[float] = None, |
|
norm_type: float = 2.0, |
|
scale_grad_by_freq: bool = False, |
|
sparse: bool = False, |
|
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, |
|
keep_master_weight_for_test: bool = False, |
|
) -> None: |
|
super(ParallelEmbedding, self).__init__() |
|
|
|
self.num_embeddings = num_embeddings |
|
self.embedding_dim = embedding_dim |
|
self.padding_idx = padding_idx |
|
self.max_norm = max_norm |
|
        self.norm_type = norm_type
|
self.scale_grad_by_freq = scale_grad_by_freq |
|
self.sparse = sparse |
|
self._weight = None |
|
|
|
world_size = get_model_parallel_world_size() |
|
self.embedding_dim_per_partition = divide_and_check_no_remainder(self.embedding_dim, world_size) |
|
|
|
|
|
self.weight = Parameter(torch.Tensor(self.num_embeddings, self.embedding_dim_per_partition)) |
|
|
|
_initialize_affine_weight( |
|
self.weight, |
|
self.num_embeddings, |
|
self.embedding_dim, |
|
self.embedding_dim_per_partition, |
|
1, |
|
init_method, |
|
stride=1, |
|
return_master_weight=False, |
|
) |
|
|
|
def forward(self, input_: torch.Tensor) -> torch.Tensor: |
|
input_parallel = copy_to_model_parallel_region(input_) |
|
output_parallel = F.embedding( |
|
input_parallel, |
|
self.weight, |
|
self.padding_idx, |
|
self.max_norm, |
|
self.norm_type, |
|
self.scale_grad_by_freq, |
|
self.sparse, |
|
) |
|
output = gather_from_model_parallel_region(output_parallel) |
|
return output |
|
|
|
|
|
class ColumnParallelLinear(torch.nn.Module): |
|
"""Linear layer with column parallelism. |
|
|
|
The linear layer is defined as Y = XA + b. A is parallelized along |
|
its second dimension as A = [A_1, ..., A_p]. |
|
|
|
Arguments: |
|
in_features: first dimension of matrix A. |
|
out_features: second dimension of matrix A. |
|
bias: If true, add bias |
|
        gather_output: If true, call all-gather on the output and make Y
                       available to all GPUs; otherwise, every GPU will have
                       its own output, which is Y_i = XA_i.
|
init_method: method to initialize weights. Note that bias is always set |
|
to zero. |
|
stride: For the strided linear layers. |
|
keep_master_weight_for_test: This was added for testing and should be |
|
set to False. It returns the master weights |
|
used for initialization. |
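
    Example (illustrative sketch; assumes the model parallel group is already
    initialized):

        # With gather_output=False each rank keeps only its slice of Y, which
        # can feed a following RowParallelLinear(input_is_parallel=True).
        linear = ColumnParallelLinear(in_features=1024, out_features=4096, gather_output=False)
        x = torch.randn(8, 1024)
        y_local = linear(x)  # shape (8, 4096 // world_size)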
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_features: int, |
|
out_features: int, |
|
bias: bool = True, |
|
gather_output: bool = True, |
|
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, |
|
stride: int = 1, |
|
keep_master_weight_for_test: bool = False, |
|
) -> None: |
|
super(ColumnParallelLinear, self).__init__() |
|
|
|
|
|
self.in_features = in_features |
|
self.out_features = out_features |
|
self.gather_output = gather_output |
|
|
|
world_size = get_model_parallel_world_size() |
|
self.output_size_per_partition = divide_and_check_no_remainder(out_features, world_size) |
|
|
|
|
|
|
|
|
|
        # Note: F.linear computes X A^T + b, so the weight is stored transposed
        # as (output_size_per_partition, in_features).
        self.weight = Parameter(torch.Tensor(self.output_size_per_partition, self.in_features))
|
if bias: |
|
self.bias = Parameter(torch.Tensor(self.output_size_per_partition)) |
|
|
|
with torch.no_grad(): |
|
self.bias.zero_() |
|
else: |
|
self.register_parameter("bias", None) |
|
|
|
|
|
self.master_weight = _initialize_affine_weight( |
|
self.weight, |
|
self.out_features, |
|
self.in_features, |
|
self.output_size_per_partition, |
|
0, |
|
init_method, |
|
stride=stride, |
|
return_master_weight=keep_master_weight_for_test, |
|
) |
|
|
|
def get_master_weight(self) -> torch.Tensor: |
|
return gather_from_model_parallel_region(self.weight.data.transpose(0, 1)).transpose_(0, 1) |
|
|
|
    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        # Copy the input so that the backward pass all-reduces the gradient.
        input_parallel = copy_to_model_parallel_region(input_)
        # Local matrix multiply with this rank's column slice of A.
        output_parallel = F.linear(input_parallel, self.weight, self.bias)
        if self.gather_output:
            # All-gather the partial outputs so every rank sees the full Y.
            output = gather_from_model_parallel_region(output_parallel)
        else:
            output = output_parallel
        return output
|
|
|
|
|
class RowParallelLinear(torch.nn.Module): |
|
"""Linear layer with row parallelism. |
|
|
|
The linear layer is defined as Y = XA + b. A is parallelized along |
|
its first dimension and X along its second dimension as: |
|
                 -   -
                | A_1 |
                | .   |
            A = | .   |        X = [X_1, ..., X_p]
                | .   |
                | A_p |
                 -   -
|
Arguments: |
|
in_features: first dimension of matrix A. |
|
out_features: second dimension of matrix A. |
|
bias: If true, add bias. Note that bias is not parallelized. |
|
input_is_parallel: If true, we assume that the input is already |
|
split across the GPUs and we do not split |
|
again. |
|
init_method: method to initialize weights. Note that bias is always set |
|
to zero. |
|
stride: For the strided linear layers. |
|
keep_master_weight_for_test: This was added for testing and should be |
|
set to False. It returns the master weights |
|
used for initialization. |
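
    Example (illustrative sketch; assumes the model parallel group is already
    initialized and the input is the ungathered output of a preceding
    ColumnParallelLinear(gather_output=False)):

        linear = RowParallelLinear(in_features=4096, out_features=1024, input_is_parallel=True)
        x_local = torch.randn(8, 4096 // get_model_parallel_world_size())
        y = linear(x_local)  # shape (8, 1024); partial results are all-reduced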
|
""" |
|
|
|
def __init__( |
|
self, |
|
in_features: int, |
|
out_features: int, |
|
bias: bool = True, |
|
input_is_parallel: bool = False, |
|
init_method: Callable[[torch.Tensor], torch.Tensor] = init.xavier_normal_, |
|
stride: int = 1, |
|
keep_master_weight_for_test: bool = False, |
|
    ) -> None:
|
super(RowParallelLinear, self).__init__() |
|
|
|
|
|
self.in_features = in_features |
|
self.out_features = out_features |
|
self.input_is_parallel = input_is_parallel |
|
|
|
world_size = get_model_parallel_world_size() |
|
self.input_size_per_partition = divide_and_check_no_remainder(in_features, world_size) |
|
|
|
|
|
|
|
|
|
        # Note: F.linear computes X A^T + b, so the weight is stored transposed
        # as (out_features, input_size_per_partition).
        self.weight = Parameter(torch.Tensor(self.out_features, self.input_size_per_partition))
|
if bias: |
|
self.bias = Parameter(torch.Tensor(self.out_features)) |
|
|
|
with torch.no_grad(): |
|
self.bias.zero_() |
|
else: |
|
self.register_parameter("bias", None) |
|
|
|
|
|
self.master_weight = _initialize_affine_weight( |
|
self.weight, |
|
self.out_features, |
|
self.in_features, |
|
self.input_size_per_partition, |
|
1, |
|
init_method, |
|
stride=stride, |
|
return_master_weight=keep_master_weight_for_test, |
|
) |
|
|
|
def get_master_weight(self) -> torch.Tensor: |
|
return gather_from_model_parallel_region(self.weight.data) |
|
|
|
    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        # Split the input along its last dimension unless it is already split.
        if self.input_is_parallel:
            input_parallel = input_
        else:
            input_parallel = scatter_to_model_parallel_region(input_)
        # Local matrix multiply with this rank's row slice of A.
        output_parallel = F.linear(input_parallel, self.weight)
        # All-reduce the partial results across model parallel ranks.
        output_ = reduce_from_model_parallel_region(output_parallel)
        if self.bias is not None:
            output = output_ + self.bias
        else:
            output = output_
        return output
|
|