from functools import partial

import pytest
import torch

from megablocks.layers.arguments import Arguments
from megablocks.layers.moe import MoE, batched_load_balancing_loss, clear_load_balancing_loss
from megablocks.layers.router import batched_router_zloss, clear_router_zloss
from tests.layers.architectures import FFN
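
# Each forward test case is (batch_size, sequence_length, hidden_size, num_experts, top_k).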
_FORWARD_TESTS = (
    (16, 1024, 512, 1, 1),
    (16, 1024, 512, 2, 1),
    (16, 1024, 512, 4, 1),
    (16, 1024, 512, 8, 1),
    (8, 2048, 512, 1, 1),
    (8, 2048, 512, 2, 1),
    (8, 2048, 512, 4, 1),
    (16, 1024, 512, 2, 2),
    (16, 1024, 512, 4, 2),
    (16, 1024, 512, 4, 4),
    (16, 1024, 512, 8, 2),
    (16, 1024, 512, 8, 4),
    (16, 1024, 512, 8, 8),
)
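
# Dense-equivalence test cases are (batch_size, sequence_length, hidden_size).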
_DENSE_TESTS = (
    (16, 1024, 512),
    (8, 2048, 512),
)


def construct_moe(
    hidden_size: int,
    ffn_hidden_size: int,
    moe_num_experts: int = 1,
    moe_capacity_factor: int = 1,
    moe_top_k: int = 1,
    moe_zloss_weight: float = 0,
):
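    """Construct a dense FFN and an MoE layer with matching arguments in fp16 on the current GPU."""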
    # The sparse MLP path is not supported with triton >= 3.2.0. Parse the versions
    # rather than comparing raw strings so that, e.g., '3.10.0' does not sort before
    # '3.2.0'.
    try:
        import triton
        from packaging import version
        if version.parse(triton.__version__) >= version.parse('3.2.0'):
            pytest.skip('Sparse MLP is not supported with triton >=3.2.0')
    except ImportError:
        pass

    init_method = partial(torch.nn.init.normal_, mean=0.0, std=0.1)
    args = Arguments(
        hidden_size=hidden_size,
        ffn_hidden_size=ffn_hidden_size,
        moe_num_experts=moe_num_experts,
        moe_capacity_factor=moe_capacity_factor,
        moe_top_k=moe_top_k,
        init_method=init_method,
        moe_zloss_weight=moe_zloss_weight,
    )

    mlp = FFN(args)
    moe_mlp = MoE(args)

    mlp.cuda(torch.cuda.current_device()).half()
    moe_mlp.cuda(torch.cuda.current_device()).half()

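    # With a single expert the MoE layer should reduce to the dense FFN, so copy the
    # expert weights into the dense model to make the two directly comparable.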
    if moe_num_experts == 1:
        with torch.no_grad():
            mlp.w1.copy_(moe_mlp.experts.mlp.w1.squeeze())
            mlp.w2.copy_(moe_mlp.experts.mlp.w2.squeeze())
    return args, mlp, moe_mlp


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward(bs: int, sl: int, hs: int, num_experts: int, top_k: int):
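    # Inputs are laid out as (sequence_length, batch_size, hidden_size).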
    x = torch.randn(sl, bs, hs).half().cuda()

    _, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
    )

    out, _ = layer(x)
    assert out.shape == x.shape
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward(
    bs: int,
    sl: int,
    hs: int,
    num_experts: int,
    top_k: int,
):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    args, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
    )

    out, _ = layer(x)
    assert out.shape == x.shape

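    # Add the auxiliary load-balancing loss so its backward path is exercised as well.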
    loss = out.sum() + batched_load_balancing_loss(args)
    loss.backward()
    layer.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs', 'num_experts', 'top_k'), _FORWARD_TESTS)
def test_moe_forward_backward_with_zloss(
    bs: int,
    sl: int,
    hs: int,
    num_experts: int,
    top_k: int,
):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    args, _, layer = construct_moe(
        hidden_size=hs,
        ffn_hidden_size=hs * 2,
        moe_num_experts=num_experts,
        moe_top_k=top_k,
        moe_zloss_weight=1e-3,
    )

    out, _ = layer(x)
    assert out.shape == x.shape

    loss = out.sum() + batched_load_balancing_loss(args)
    loss.backward()
    layer.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()
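    # Also clear the router z-loss buffer, since z-loss was enabled for this test.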
    clear_router_zloss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_vs_dense(bs: int, sl: int, hs: int):
    x = torch.randn(sl, bs, hs).half().cuda()

    _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)

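    # With a single expert and copied weights, the MoE output should match the dense FFN.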
    expected_out = mlp(x)
    out, _ = moe_mlp(x)
    assert out.shape == x.shape == expected_out.shape
    assert torch.allclose(out, expected_out)
    clear_load_balancing_loss()


@pytest.mark.gpu
@pytest.mark.parametrize(('bs', 'sl', 'hs'), _DENSE_TESTS)
def test_moe_forward_backward_vs_dense(bs: int, sl: int, hs: int):
    x = torch.randn(sl, bs, hs).half().cuda()
    x.requires_grad_(True)

    _, mlp, moe_mlp = construct_moe(hidden_size=hs, ffn_hidden_size=hs * 2)

    out, _ = moe_mlp(x)
    loss = out.sum()
    loss.backward()
    w1_grad = moe_mlp.experts.mlp.w1.grad.detach().squeeze()
    w2_grad = moe_mlp.experts.mlp.w2.grad.detach().squeeze()
    moe_mlp.zero_grad(set_to_none=True)
    x.grad = None
    clear_load_balancing_loss()

    expected_out = mlp(x)
    expected_loss = expected_out.sum()
    expected_loss.backward()
    expected_w1_grad = mlp.w1.grad.detach()
    expected_w2_grad = mlp.w2.grad.detach()
    mlp.zero_grad(set_to_none=True)
    x.grad = None

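    # The single expert's weight gradients should match the dense FFN's gradients.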
    assert w1_grad.shape == expected_w1_grad.shape
    assert w2_grad.shape == expected_w2_grad.shape
    assert torch.allclose(w1_grad, expected_w1_grad)
    assert torch.allclose(w2_grad, expected_w2_grad)
    clear_load_balancing_loss()