|
|
|
|
|
|
|
|
|
""" |
|
The file has been adapted from two fairscale files: |
|
(1) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/moe_layer.py |
|
(2) https://github.com/facebookresearch/fairscale/blob/master/fairscale/nn/moe/top2gate.py |
|
Git commit hash: 34df606902a240567a0d898037ece55c2f1336cf |
|
We retain the following license from the original files: |
|
""" |
|
|
|
|
|
|
|
|
|
|
|
|
|
from deepspeed.utils.timer import SynchronizedWallClockTimer |
|
from deepspeed.utils import logger |
|
from deepspeed.utils.bwc import bwc_tensor_model_parallel_world_size |
|
from typing import Callable, Dict, TYPE_CHECKING, Any, Optional, Tuple, Union |
|
|
|
import torch |
|
from torch import Tensor |
|
from torch.nn import Module |
|
import torch.nn.functional as F |
|
from deepspeed.utils import groups |
|
from .mappings import drop_tokens, gather_tokens |
|
|
|
if TYPE_CHECKING: |
|
Base = Module[Tensor] |
|
else: |
|
Base = Module |
|
|
|
TOPK_GATE_TIMER = 'topk_gate' |
|
MOE_TIMER = 'moe' |
|
FIRST_ALLTOALL_TIMER = '1st_a2a' |
|
SECOND_ALLTOALL_TIMER = '2nd_a2a' |
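# Per-device caches for the samplers used by the gating functions below.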
|
|
|
uniform_map: Dict[torch.device, Callable] = {} |
|
gumbel_map: Dict[torch.device, Callable] = {} |
|
exp_selection_uniform_map: Dict[torch.device, Callable] = {} |
|
|
|
try: |
|
|
|
|
|
from tutel import moe as tutel_moe |
|
TUTEL_INSTALLED = True |
|
except Exception:
    # Tutel is an optional dependency; fall back to the standard dispatch path when it is unavailable.
    TUTEL_INSTALLED = False
|
|
|
|
|
def multiplicative_jitter(x, device: torch.device, epsilon=1e-2): |
|
""" |
|
    Modified from the Switch Transformer paper and its Mesh-TensorFlow implementation.
    Multiplies values by a random number drawn uniformly from [1 - epsilon, 1 + epsilon].
|
Makes models more resilient to rounding errors introduced by bfloat16. |
|
This seems particularly important for logits. |
|
Args: |
|
x: a torch.tensor |
|
device: torch.device |
|
epsilon: a floating point value |
|
Returns: |
|
a jittered x. |
|
""" |
|
if epsilon == 0: |
|
return x |
|
uniform = uniform_map.get(device) |
|
if uniform is None: |
|
        uniform = torch.distributions.uniform.Uniform(
            low=torch.tensor(1.0 - epsilon, device=device),
            high=torch.tensor(1.0 + epsilon, device=device)).rsample
|
uniform_map[device] = uniform |
|
return x * uniform(x.shape) |
|
|
|
|
|
def gumbel_rsample(shape: Tuple, device: torch.device) -> Tensor: |
|
gumbel = gumbel_map.get(device) |
|
if gumbel is None: |
|
one = torch.tensor(1.0, device=device) |
|
zero = torch.tensor(0.0, device=device) |
|
gumbel = torch.distributions.gumbel.Gumbel(zero, one).rsample |
|
gumbel_map[device] = gumbel |
|
return gumbel(shape) |
|
|
|
|
|
from deepspeed import comm as dist |
|
|
|
|
|
|
|
|
|
|
|
|
|
class _AllToAll(torch.autograd.Function): |
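    """Differentiable all-to-all: the backward pass is another all-to-all applied to the gradients."""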
|
|
|
@staticmethod |
|
def forward(ctx: Any, group: dist.ProcessGroup, input: Tensor) -> Tensor: |
|
ctx.group = group |
|
input = input.contiguous() |
|
output = torch.empty_like(input) |
|
dist.all_to_all_single(output, input, group=group) |
|
return output |
|
|
|
@staticmethod |
|
def backward(ctx: Any, *grad_output: Tensor) -> Tuple[None, Tensor]: |
|
return (None, _AllToAll.apply(ctx.group, *grad_output)) |
|
|
|
|
|
|
|
|
|
USE_EINSUM = True |
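# Shorthand used in the einsum rules below: s = tokens, e = experts, c = capacity, m = model dim, k = top-k.
# The hand-written branches in einsum() are equivalent matmul formulations used when USE_EINSUM is False.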
|
|
|
|
|
|
|
|
|
def einsum(rule, a, b): |
|
if USE_EINSUM: |
|
return torch.einsum(rule, a, b) |
|
elif rule == 's,se->se': |
|
return a.reshape(a.shape[0], -1) * b |
|
elif rule == 'se,sc->sec': |
|
return a.unsqueeze(2) * b.unsqueeze(1) |
|
elif rule == 'se,se->s': |
|
return torch.bmm(a.unsqueeze(1), b.unsqueeze(2)).reshape(-1) |
|
elif rule == 'se,sec->sec': |
|
return a.unsqueeze(2) * b |
|
elif rule == 'sec,sm->ecm': |
|
s = a.shape[0] |
|
e = a.shape[1] |
|
c = a.shape[2] |
|
m = b.shape[1] |
|
return torch.matmul(a.reshape(s, -1).t(), b).reshape(e, c, m) |
|
elif rule == 'sec,ecm->sm': |
|
return torch.matmul(a.reshape(a.shape[0], -1), b.reshape(-1, b.shape[-1])) |
|
elif rule == 'ks,ksm->sm': |
|
k = b.shape[0] |
|
s = b.shape[1] |
|
m = b.shape[2] |
|
|
|
a = a.t().unsqueeze(1) |
|
|
|
b = b.reshape(k, -1).t().reshape(s, m, k) |
|
|
|
return torch.bmm(a, b.transpose(1, 2)).squeeze(2) |
|
else: |
|
return torch.einsum(rule, a, b) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@torch.jit.script |
|
def _capacity(gates: Tensor, capacity_factor: Tensor, min_capacity: Tensor) -> Tensor: |
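    # gates: (num_tokens, num_experts). Capacity is ceil(tokens / experts * capacity_factor), floored at min_capacity.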
|
|
|
num_tokens = gates.shape[0] |
|
num_experts = gates.shape[1] |
|
|
|
|
|
capacity = torch.ceil((num_tokens / num_experts) * capacity_factor).to(torch.int64) |
|
if capacity < min_capacity: |
|
capacity = min_capacity.to(torch.int64) |
|
return capacity |
|
|
|
|
|
@torch.jit.script |
|
def _top_idx(source, k): |
|
return torch.topk(source, k=k, dim=0)[1] |
|
|
|
|
|
@torch.jit.script |
|
def _one_hot_to_float(x, num_classes): |
|
return F.one_hot(x, num_classes=num_classes).float() |
|
|
|
|
|
def top1gating(logits: Tensor, |
|
capacity_factor: float, |
|
min_capacity: int, |
|
               used_token: Optional[Tensor] = None,
|
noisy_gate_policy: Optional[str] = None, |
|
drop_tokens: bool = True, |
|
use_rts: bool = True, |
|
ep_group: Union[torch.distributed.ProcessGroup, None] = None, |
|
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor, Tensor]: |
|
"""Implements Top1Gating on logits.""" |
|
if noisy_gate_policy == 'RSample': |
|
logits_w_noise = logits + gumbel_rsample(logits.shape, device=logits.device) |
|
|
|
|
|
gates = F.softmax(logits, dim=1) |
|
capacity = _capacity(gates, torch.tensor(capacity_factor), torch.tensor(min_capacity)) |
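    # Create a one-hot mask of each token's top-1 expert (using the noisy logits under the 'RSample' policy).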
|
|
|
|
|
|
|
indices1_s = torch.argmax(logits_w_noise if noisy_gate_policy == 'RSample' else gates, dim=1) |
|
num_experts = int(gates.shape[1]) |
|
mask1 = F.one_hot(indices1_s, num_classes=num_experts) |
|
|
|
|
|
if used_token is not None: |
|
mask1 = einsum("s,se->se", used_token, mask1) |
|
|
|
|
|
exp_counts = torch.sum(mask1, dim=0).detach().to(logits.device) |
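    # If token dropping is disabled, grow the capacity to the busiest expert's load and
    # synchronize it across the expert-parallel group.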
|
|
|
|
|
if not drop_tokens: |
|
new_capacity = torch.max(exp_counts).to(logits.device) |
|
|
|
if ep_group is not None: |
|
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) |
|
if groups._get_expert_model_parallel_world_size() == 1: |
|
|
|
|
|
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) |
|
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) |
|
|
|
capacity = min(new_capacity, torch.tensor(mask1.size(0)).to(new_capacity.device)) |
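    # GShard load-balancing auxiliary loss: mean gate probability per expert times the fraction of tokens
    # routed to that expert, summed over experts and scaled by the number of experts.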
|
|
|
|
|
me = torch.mean(gates, dim=0) |
|
ce = torch.mean(mask1.float(), dim=0) |
|
l_aux = torch.sum(me * ce) * num_experts |
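    # Random Token Selection: assign a random score to each routed token so that, if an expert is over capacity,
    # the kept tokens are chosen at random rather than by position in the batch.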
|
|
|
|
|
if use_rts: |
|
uniform = exp_selection_uniform_map.get(logits.device) |
|
if uniform is None: |
|
uniform = torch.distributions.uniform.Uniform(low=torch.tensor(0.0, device=logits.device), |
|
high=torch.tensor(1.0, device=logits.device)).rsample |
|
exp_selection_uniform_map[logits.device] = uniform |
|
|
|
mask1_rand = mask1 * uniform(mask1.shape) |
|
else: |
|
mask1_rand = mask1 |
|
|
|
    assert logits.shape[0] >= min_capacity, (
        "No. of tokens (batch-size) should be at least min_capacity. "
        "Either set min_capacity to 0 or increase your batch size.")
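    # Keep at most `capacity` tokens per expert, chosen by the scores in mask1_rand, and rebuild the mask.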
|
|
|
top_idx = _top_idx(mask1_rand, capacity) |
|
|
|
new_mask1 = mask1 * torch.zeros_like(mask1).scatter_(0, top_idx, 1) |
|
mask1 = new_mask1 |
|
|
|
if use_tutel: |
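        # For Tutel, mark tokens that were dropped above by setting their routing index to -1.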
|
|
|
|
|
indices_mask = mask1.sum(dim=1) * num_experts - 1 |
|
indices1_s = torch.min(indices1_s, indices_mask) |
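    # Compute each token's position (location) within its assigned expert's buffer.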
|
|
|
|
|
if use_tutel: |
|
locations1 = tutel_moe.fast_cumsum_sub_one(mask1) |
|
else: |
|
locations1 = torch.cumsum(mask1, dim=0) - 1 |
|
|
|
if use_tutel: |
|
gates1_s = (gates * mask1).sum(dim=1) |
|
locations1_s = torch.sum(locations1 * mask1, dim=1) |
|
        return l_aux, capacity, num_experts, [indices1_s], [locations1_s], [gates1_s], exp_counts
|
|
|
|
|
locations1_s = torch.sum(locations1 * mask1, dim=1) |
|
|
|
|
|
mask1_float = mask1.float() |
|
gates = gates * mask1_float |
|
|
|
locations1_sc = _one_hot_to_float(locations1_s, capacity) |
|
combine_weights = einsum("se,sc->sec", gates, locations1_sc) |
|
|
|
dispatch_mask = combine_weights.bool() |
|
|
|
return l_aux, combine_weights, dispatch_mask, exp_counts |
|
|
|
|
|
def top2gating(logits: Tensor, |
|
capacity_factor: float, |
|
min_capacity: int, |
|
drop_tokens: bool = True, |
|
ep_group: Union[torch.distributed.ProcessGroup, None] = None, |
|
top2_2nd_expert_sampling: bool = True) -> Tuple[Tensor, Tensor, Tensor, Tensor]: |
|
"""Implements Top2Gating on logits.""" |
|
|
|
gates = F.softmax(logits, dim=1) |
|
|
|
|
|
indices1_s = torch.argmax(gates, dim=1) |
|
num_experts = int(gates.shape[1]) |
|
mask1 = F.one_hot(indices1_s, num_classes=num_experts) |
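    # Optionally add Gumbel noise to the logits before selecting each token's second expert (Gumbel-max sampling).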
|
|
|
if top2_2nd_expert_sampling: |
|
|
|
|
|
logits += gumbel_rsample(logits.shape, device=logits.device) |
|
|
|
|
|
logits_except1 = logits.masked_fill(mask1.bool(), float("-inf")) |
|
indices2_s = torch.argmax(logits_except1, dim=1) |
|
mask2 = F.one_hot(indices2_s, num_classes=num_experts) |
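    # Compute locations within each expert's buffer; second-choice tokens are placed after all first choices.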
|
|
|
|
|
locations1 = torch.cumsum(mask1, dim=0) - 1 |
|
locations2 = torch.cumsum(mask2, dim=0) - 1 |
|
|
|
locations2 += torch.sum(mask1, dim=0, keepdim=True) |
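    # GShard load-balancing auxiliary loss, computed from the top-1 assignments.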
|
|
|
|
|
me = torch.mean(gates, dim=0) |
|
ce = torch.mean(mask1.float(), dim=0) |
|
l_aux = torch.mean(me * ce) * num_experts * num_experts |
|
|
|
|
|
exp_counts = torch.sum(mask1 + mask2, dim=0).detach().to(logits.device) |
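    # With token dropping, discard tokens that land beyond an expert's capacity; otherwise grow the capacity
    # to the busiest expert and synchronize it across the expert-parallel group.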
|
|
|
if drop_tokens: |
|
|
|
capacity = _capacity(gates, torch.tensor(capacity_factor * 2), torch.tensor(min_capacity)) |
|
mask1 *= torch.lt(locations1, capacity) |
|
mask2 *= torch.lt(locations2, capacity) |
|
else: |
|
|
|
new_capacity = torch.max(exp_counts) |
|
if ep_group is not None: |
|
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) |
|
if groups._get_expert_model_parallel_world_size() == 1: |
|
|
|
|
|
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) |
|
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) |
|
capacity = new_capacity |
|
|
|
|
|
locations1_s = torch.sum(locations1 * mask1, dim=1) |
|
locations2_s = torch.sum(locations2 * mask2, dim=1) |
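    # Normalize the two gate values of each token so they sum to one, then build the combine weights.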
|
|
|
|
|
mask1_float = mask1.float() |
|
mask2_float = mask2.float() |
|
gates1_s = einsum("se,se->s", gates, mask1_float) |
|
gates2_s = einsum("se,se->s", gates, mask2_float) |
|
denom_s = gates1_s + gates2_s |
|
|
|
denom_s = torch.clamp(denom_s, min=torch.finfo(denom_s.dtype).eps) |
|
gates1_s /= denom_s |
|
gates2_s /= denom_s |
|
|
|
|
|
gates1 = einsum("s,se->se", gates1_s, mask1_float) |
|
gates2 = einsum("s,se->se", gates2_s, mask2_float) |
|
locations1_sc = _one_hot_to_float(locations1_s, capacity) |
|
locations2_sc = _one_hot_to_float(locations2_s, capacity) |
|
combine1_sec = einsum("se,sc->sec", gates1, locations1_sc) |
|
combine2_sec = einsum("se,sc->sec", gates2, locations2_sc) |
|
combine_weights = combine1_sec + combine2_sec |
|
dispatch_mask = combine_weights.bool() |
|
|
|
return l_aux, combine_weights, dispatch_mask, exp_counts |
|
|
|
|
|
def topkgating( |
|
logits: Tensor, |
|
k: int, |
|
capacity_factor: float, |
|
min_capacity: int, |
|
drop_tokens: bool = True, |
|
ep_group: Union[torch.distributed.ProcessGroup, None] = None, |
|
drop_policy: str = "probs", |
|
) -> Tuple[Tensor, Tensor, Tensor, Tensor]: |
|
"""Implements TopKGating on logits.""" |
|
|
|
|
|
|
|
top_gate, top_idx = torch.topk(logits, k=k, dim=1) |
|
|
|
gates = F.softmax(logits, dim=1) |
|
num_experts = int(gates.shape[1]) |
|
|
|
|
|
topk_masked_gates = torch.zeros_like(logits).scatter(1, top_idx, top_gate) |
|
|
|
mask = torch.zeros_like(gates, dtype=torch.bool).scatter_(1, top_idx, 1) |
|
|
|
exp_counts = torch.sum(mask, dim=0).detach().to(logits.device) |
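    # Load-balancing auxiliary loss, scaled by 1/k for top-k routing.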
|
|
|
|
|
me = torch.mean(gates, dim=0) |
|
ce = torch.mean(mask.float(), dim=0) |
|
l_aux = torch.mean(me * ce) * num_experts * num_experts / k |
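    # Capacity enforcement: 'probs' keeps the highest-scoring tokens per expert, 'position' keeps the earliest
    # tokens in the batch.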
|
|
|
if drop_tokens: |
|
|
|
capacity = _capacity(gates, torch.tensor(capacity_factor * k), torch.tensor(min_capacity)) |
|
|
|
|
|
if drop_policy == 'probs': |
|
capacity_probs, capacity_indices = torch.topk(topk_masked_gates, k=capacity, dim=0, sorted=False) |
|
capacity_mask = torch.zeros_like(logits).scatter(0, capacity_indices, 1) |
|
mask = torch.logical_and(mask, capacity_mask) |
|
locations = torch.cumsum(mask, dim=0) - 1 |
|
|
|
elif drop_policy == "position": |
|
locations = torch.cumsum(mask, dim=0) - 1 |
|
mask *= torch.lt(locations, capacity) |
|
else: |
|
raise ValueError(f"Invalid drop_policy: {drop_policy}") |
|
|
|
else: |
|
|
|
new_capacity = torch.max(exp_counts) |
|
if ep_group is not None: |
|
dist.all_reduce(new_capacity, op=dist.ReduceOp.MAX, group=ep_group) |
|
if groups._get_expert_model_parallel_world_size() == 1: |
|
|
|
|
|
tp = 1 if groups.mpu is None else bwc_tensor_model_parallel_world_size(mpu=groups.mpu) |
|
new_capacity = torch.ceil(new_capacity / tp).mul(tp).to(new_capacity.dtype) |
|
capacity = new_capacity |
|
locations = torch.cumsum(mask, dim=0) - 1 |
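    # Normalize the selected gate values of each token so they sum to one.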
|
|
|
|
|
gates_masked = gates * mask |
|
gates_s = torch.sum(gates_masked, dim=-1, keepdim=True) |
|
denom_s = torch.clamp(gates_s, min=torch.finfo(gates_masked.dtype).eps) |
|
gates_masked = gates_masked / denom_s |
|
|
|
|
|
locations_sc = _one_hot_to_float((locations * mask), capacity) |
|
|
|
combine_weights = torch.einsum("se,sec->sec", gates_masked, locations_sc) |
|
|
|
dispatch_mask = combine_weights.bool() |
|
|
|
return l_aux, combine_weights, dispatch_mask, exp_counts |
|
|
|
|
|
class TopKGate(Module): |
|
"""Gate module which implements Top2Gating as described in Gshard_. |
|
:: |
|
|
|
gate = TopKGate(model_dim, num_experts) |
|
        l_aux, combine_weights, dispatch_mask, exp_counts = gate(input)
|
|
|
    .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
|
|
|
Args: |
|
model_dim (int): |
|
size of model embedding dimension |
|
num_experts (int): |
|
number of experts in model |
|
""" |
|
|
|
wg: torch.nn.Linear |
|
|
|
def __init__(self, |
|
model_dim: int, |
|
num_experts: int, |
|
k: int = 1, |
|
capacity_factor: float = 1.0, |
|
eval_capacity_factor: float = 1.0, |
|
min_capacity: int = 8, |
|
noisy_gate_policy: Optional[str] = None, |
|
drop_tokens: bool = True, |
|
use_rts: bool = True, |
|
ep_group: Union[torch.distributed.ProcessGroup, None] = None, |
|
top2_2nd_expert_sampling: bool = True) -> None: |
|
super().__init__() |
|
|
|
self.wg = torch.nn.Linear(model_dim, num_experts, bias=False) |
|
self.ep_group = ep_group |
|
self.k = k |
|
self.capacity_factor = capacity_factor |
|
self.eval_capacity_factor = eval_capacity_factor |
|
self.min_capacity = min_capacity |
|
self.noisy_gate_policy = noisy_gate_policy |
|
self.timers = SynchronizedWallClockTimer() |
|
self.wall_clock_breakdown = False |
|
self.gate_time = 0.0 |
|
self.drop_tokens = drop_tokens |
|
self.use_rts = use_rts |
|
self.top2_2nd_expert_sampling = top2_2nd_expert_sampling |
|
|
|
def _set_ep_group(self, ep_group): |
|
        assert self.ep_group is None, 'Attempting to override an existing ep_group'
|
self.ep_group = ep_group |
|
|
|
def forward(self, |
|
input: torch.Tensor, |
|
                used_token: Optional[torch.Tensor] = None,
|
use_tutel: bool = False) -> Tuple[Tensor, Tensor, Tensor]: |
|
|
|
if self.wall_clock_breakdown: |
|
self.timers(TOPK_GATE_TIMER).start() |
|
|
|
input_fp32 = input.float() |
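        # Gate computation runs in fp32; optionally jitter the input during training before computing the logits.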
|
|
|
if self.noisy_gate_policy == 'Jitter' and self.training: |
|
input_fp32 = multiplicative_jitter(input_fp32, device=input.device) |
|
logits = torch.nn.functional.linear(input_fp32, weight=self.wg.weight.float(), bias=None) |
|
|
|
if self.k == 1: |
|
gate_output = top1gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, |
|
self.min_capacity, used_token, self.noisy_gate_policy if self.training else None, |
|
self.drop_tokens, self.use_rts, self.ep_group, use_tutel) |
|
|
|
elif self.k == 2: |
|
gate_output = top2gating(logits, self.capacity_factor if self.training else self.eval_capacity_factor, |
|
self.min_capacity, self.drop_tokens, self.ep_group, self.top2_2nd_expert_sampling) |
|
else: |
|
gate_output = topkgating(logits, self.k, |
|
self.capacity_factor if self.training else self.eval_capacity_factor, |
|
self.min_capacity, self.drop_tokens, self.ep_group) |
|
|
|
if self.wall_clock_breakdown: |
|
self.timers(TOPK_GATE_TIMER).stop() |
|
self.gate_time = self.timers(TOPK_GATE_TIMER).elapsed(reset=False) |
|
|
|
return gate_output |
|
|
|
|
|
class MOELayer(Base): |
|
"""MOELayer module which implements MixtureOfExperts as described in Gshard_. |
|
:: |
|
|
|
gate = TopKGate(model_dim, num_experts) |
|
moe = MOELayer(gate, expert) |
|
output = moe(input) |
|
l_aux = moe.l_aux |
|
|
|
    .. _Gshard: https://arxiv.org/pdf/2006.16668.pdf
|
|
|
Args: |
|
gate (torch.nn.Module): |
|
gate network |
|
expert (torch.nn.Module): |
|
expert network |
|
""" |
|
|
|
def __init__(self, |
|
gate: Module, |
|
experts: Module, |
|
ep_group_name, |
|
ep_size, |
|
num_local_experts: int, |
|
use_tutel: bool = False) -> None: |
|
super().__init__() |
|
self.gate = gate |
|
self.experts = experts |
|
self.ep_group = None |
|
self.ep_size = ep_size |
|
self.ep_group_name = ep_group_name |
|
self.num_local_experts = num_local_experts |
|
self.time_falltoall = 0.0 |
|
self.time_salltoall = 0.0 |
|
self.time_moe = 0.0 |
|
self.timers = SynchronizedWallClockTimer() |
|
self.wall_clock_breakdown = False |
|
|
|
self.use_tutel = use_tutel and TUTEL_INSTALLED and gate.k == 1 |
|
|
|
if self.use_tutel: |
|
logger.info('Using Tutel optimizations.') |
|
elif use_tutel and not TUTEL_INSTALLED: |
|
logger.warning("Tutel optimization requested but not installed. " |
|
"Proceeding without Tutel.") |
|
elif use_tutel and TUTEL_INSTALLED and gate.k != 1: |
|
logger.warning("To enable Tutel optimization, use top-1 instead of top-2 gate. " |
|
"Proceeding without Tutel.") |
|
|
|
def _set_ep_group(self, ep_group): |
|
self.ep_group = ep_group |
|
self.gate._set_ep_group(ep_group) |
|
|
|
def forward(self, *input: Tensor, **kwargs: Any) -> Tensor: |
|
|
|
if self.wall_clock_breakdown: |
|
self.timers(MOE_TIMER).start() |
|
|
|
|
|
d_model = input[0].shape[-1] |
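        # Flatten the input to (tokens, d_model) for gating and dispatch.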
|
|
|
|
|
|
|
|
|
reshaped_input = input[0].reshape(-1, d_model) |
|
|
|
if self.use_tutel: |
|
self.l_aux, C, E, indices_, locations_, gates_, self.exp_counts = self.gate(reshaped_input, input[1], True) |
|
S, M = reshaped_input.size(0), reshaped_input.size(1) |
|
|
|
if not hasattr(self, '_tutel_dispatcher'): |
|
self._tutel_dispatcher = tutel_moe.fast_dispatcher(E, C, M, dispatch_dtype=reshaped_input.dtype) |
|
self._tutel_dispatcher.update(indices_, locations_, gates_, capacity=C) |
|
dispatched_input = self._tutel_dispatcher.encode(reshaped_input) |
|
else: |
|
self.l_aux, combine_weights, dispatch_mask, self.exp_counts = self.gate(reshaped_input, input[1]) |
|
dispatched_input = einsum("sec,sm->ecm", dispatch_mask.type_as(input[0]), reshaped_input) |
|
|
|
if self.wall_clock_breakdown: |
|
self.timers(FIRST_ALLTOALL_TIMER).start() |
|
|
|
tensor_model_world_size = bwc_tensor_model_parallel_world_size(groups.mpu) |
|
if tensor_model_world_size > 1: |
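            # Each rank in a tensor-parallel group holds a duplicate copy of the tokens. Drop the duplicates
            # before the all-to-all to reduce communication volume; they are gathered back below if the experts
            # are themselves tensor-parallel.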
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
dispatched_input = drop_tokens(dispatched_input, dim=1) |
|
|
|
dispatched_input = _AllToAll.apply(self.ep_group, dispatched_input) |
|
|
|
if self.wall_clock_breakdown: |
|
self.timers(FIRST_ALLTOALL_TIMER).stop() |
|
self.time_falltoall = self.timers(FIRST_ALLTOALL_TIMER).elapsed(reset=False) |
|
|
|
if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1: |
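            # The experts are tensor-parallel, so gather the previously dropped duplicate tokens back on each
            # tensor-parallel rank before running the local experts.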
|
|
|
|
|
|
|
dispatched_input = gather_tokens(dispatched_input, dim=1) |
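        # Reshape for the local experts: (ep_size, num_local_experts, tokens per expert, d_model).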
|
|
|
|
|
dispatched_input = dispatched_input.reshape(self.ep_size, self.num_local_experts, -1, d_model) |
|
expert_output = self.experts(dispatched_input) |
|
|
|
expert_output = expert_output.reshape(self.ep_size * self.num_local_experts, -1, d_model) |
|
if tensor_model_world_size > 1 and groups._get_expert_model_parallel_world_size() > 1: |
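            # Mirror of the pre-expert gather: drop the duplicate tokens again before the second all-to-all.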
|
|
|
|
|
|
|
expert_output = drop_tokens(expert_output, dim=1) |
|
|
|
if self.wall_clock_breakdown: |
|
self.timers(SECOND_ALLTOALL_TIMER).start() |
|
|
|
expert_output = _AllToAll.apply(self.ep_group, expert_output) |
|
|
|
if self.wall_clock_breakdown: |
|
self.timers(SECOND_ALLTOALL_TIMER).stop() |
|
self.time_salltoall = self.timers(SECOND_ALLTOALL_TIMER).elapsed(reset=False) |
|
|
|
if tensor_model_world_size > 1: |
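            # Gather the tokens back on each tensor-parallel rank after the second all-to-all.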
|
|
|
|
|
|
|
expert_output = gather_tokens(expert_output, dim=1) |
|
|
|
if self.use_tutel: |
|
combined_output = self._tutel_dispatcher.decode(expert_output.view(E * C, M)) |
|
else: |
|
combined_output = einsum("sec,ecm->sm", combine_weights.type_as(input[0]), expert_output) |
|
|
|
a = combined_output.reshape(input[0].shape) |
|
|
|
if self.wall_clock_breakdown: |
|
self.timers(MOE_TIMER).stop() |
|
self.time_moe = self.timers(MOE_TIMER).elapsed(reset=False) |
|
|
|
return a |
|
|