from typing import Callable, Dict, Tuple

import torch
from torch import Tensor
import torch.nn.functional as F

# Cache of one Gumbel(0, 1) sampler per device, built lazily on first use.
gumbel_map: Dict[torch.device, Callable] = {}


def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor:
    """Draw reparameterized Gumbel(0, 1) noise of the given shape on the given device."""
    gumbel = gumbel_map.get(device)
    if gumbel is None:
        one = torch.tensor(1.0, device=device)
        zero = torch.tensor(0.0, device=device)
        gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample
        gumbel_map[device] = gumbel
    return gumbel(shape)


def one_hot(tensor: torch.Tensor, num_classes: int) -> Tensor:
    """Workaround for https://github.com/pytorch/pytorch/issues/55579

    Converts integer indices to one-hot vectors along a new trailing dimension,
    preserving the input tensor's dtype and device.
    """
    assert num_classes > 0, "num_classes must be a positive integer"
    ret = torch.zeros(tensor.shape + (num_classes,), device=tensor.device, dtype=tensor.dtype)
    ret.scatter_(-1, tensor.unsqueeze(-1), 1)
    return ret
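# Illustrative example (not part of the original file):
#   one_hot(torch.tensor([0, 2]), num_classes=3)
#   -> tensor([[1, 0, 0],
#              [0, 0, 1]])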


def top2gating(logits: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
    """Implements Top2Gating on logits of shape [num_tokens, num_experts]."""
    # Compute gate probabilities in fp32 for numerical stability.
    gates = F.softmax(logits, dim=1, dtype=torch.float)

    # gates has shape [num_tokens (S), num_experts (E)].
    num_tokens = gates.shape[0]
    num_experts = gates.shape[1]
    # Each expert holds at most 2S/E tokens when every token is routed to two experts.
    capacity = 2 * num_tokens // num_experts
    assert num_tokens % num_experts == 0

    # Mask of each token's top-1 expert.
    indices1_s = torch.argmax(gates, dim=1)
    mask1 = one_hot(indices1_s, num_classes=num_experts)

    # Pick the second expert stochastically via the Gumbel-max trick,
    # excluding the top-1 expert by masking its noisy logit to -inf.
    logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device)
    logits_except1 = logits_w_noise.masked_fill(mask1.bool(), float("-inf"))
    indices2_s = torch.argmax(logits_except1, dim=1)
    mask2 = one_hot(indices2_s, num_classes=num_experts)

    # Position of each token within its expert's capacity buffer.
    locations1 = torch.cumsum(mask1, dim=0) - 1
    locations2 = torch.cumsum(mask2, dim=0) - 1
    # Second-choice tokens are placed after all first-choice tokens of the same expert.
    locations2 += torch.sum(mask1, dim=0, keepdim=True)

    # Auxiliary load-balancing loss: mean gate probability times mean top-1
    # assignment fraction, averaged over experts.
    me = torch.mean(gates, dim=0)
    ce = torch.mean(mask1.float(), dim=0)
    l_aux = torch.mean(me * ce)

    # Drop tokens that overflow an expert's capacity.
    mask1 *= torch.lt(locations1, capacity)
    mask2 *= torch.lt(locations2, capacity)

    # Capacity-slot index of each surviving token (0 for dropped tokens).
    locations1_s = torch.sum(locations1 * mask1, dim=1)
    locations2_s = torch.sum(locations2 * mask2, dim=1)

    # Normalize the two gate values so they sum to one per token.
    gates1_s = (gates * mask1).sum(dim=1)
    gates2_s = (gates * mask2).sum(dim=1)
    denom_s = gates1_s + gates2_s
    # Avoid divide-by-zero when both assignments were dropped.
    denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps)
    gates1_s /= denom_s
    gates2_s /= denom_s

    # Combine weights over [token (S), expert (E), capacity slot (C)] and the
    # boolean dispatch mask derived from them.
    gates1 = gates1_s.unsqueeze(-1) * mask1
    gates2 = gates2_s.unsqueeze(-1) * mask2
    locations1_sc = one_hot(locations1_s, num_classes=capacity)
    locations2_sc = one_hot(locations2_s, num_classes=capacity)
    combine1_sec = gates1.unsqueeze(2) * locations1_sc.unsqueeze(1)
    combine2_sec = gates2.unsqueeze(2) * locations2_sc.unsqueeze(1)
    combine_weights = combine1_sec + combine2_sec
    dispatch_mask = combine_weights.bool()

    return l_aux.to(logits.dtype), combine_weights.to(logits.dtype), dispatch_mask
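# Shape sketch (illustrative, not part of the original file): with S = 8 tokens and
# E = 4 experts, capacity = 2 * 8 // 4 = 4, so for
#   l_aux, combine_weights, dispatch_mask = top2gating(torch.randn(8, 4))
# l_aux is a scalar, while combine_weights and dispatch_mask have shape [8, 4, 4]
# ([tokens, experts, capacity slots]).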


class Top2Gate(torch.nn.Module):
    """Gate module which implements Top2Gating as described in Gshard_.
    ::

        gate = Top2Gate(model_dim, num_experts)
        l_aux, combine_weights, dispatch_mask = gate(input)

    .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf

    Args:
        model_dim (int):
            size of model embedding dimension
        num_experts (int):
            number of experts in model
    """

    wg: torch.nn.Linear

    def __init__(
        self,
        model_dim: int,
        num_experts: int,
    ) -> None:
        super().__init__()
        # Gating weights: project each token embedding to one logit per expert.
        self.wg = torch.nn.Linear(model_dim, num_experts, bias=False)

    def forward(self, input: torch.Tensor) -> Tuple[Tensor, Tensor, Tensor]:
        logits = self.wg(input)
        return top2gating(logits)
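

if __name__ == "__main__":
    # Minimal, self-contained demo (an illustrative sketch, not part of the original
    # module). The dispatch/combine einsum pattern below follows the GShard paper and
    # is only an assumption about how a surrounding MoE layer would consume the gate's
    # outputs; the expert "computation" here is a stand-in that reuses the inputs.
    torch.manual_seed(0)

    num_tokens, model_dim, num_experts = 8, 16, 4
    tokens = torch.randn(num_tokens, model_dim)  # [S, M]

    gate = Top2Gate(model_dim, num_experts)
    l_aux, combine_weights, dispatch_mask = gate(tokens)

    # Route tokens into per-expert capacity buffers: [E, C, M].
    dispatched = torch.einsum("sec,sm->ecm", dispatch_mask.to(tokens.dtype), tokens)
    # Pretend each expert applies some transform; here the buffers pass through unchanged.
    expert_output = dispatched
    # Weighted combine back to per-token outputs: [S, M].
    combined = torch.einsum("sec,ecm->sm", combine_weights, expert_output)

    print("l_aux:", l_aux.item())
    print("combine_weights shape:", tuple(combine_weights.shape))
    print("combined output shape:", tuple(combined.shape))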