|
|
|
|
|
|
|
import numpy as np |
|
import pytest |
|
import torch |
|
|
|
from megablocks import ops |
|
|
|
# Parametrization cases for testPaddedScatter. Each tuple is
# (sl, hs, ne, top_k), in the order the test's parametrize decorator names
# them. Several configurations are listed more than once; the repeats are kept
# exactly as they were (NOTE(review): confirm whether they are intentional).
PADDED_SCATTER_TESTS = [
    # sl=4, hs=2
    (4, 2, 2, 2),
    (4, 2, 2, 1),
    (4, 2, 2, 1),
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (4, 2, 2, 2),
    # sl=1024, hs=1
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 4, 1),
    (1024, 1, 4, 2),
    (1024, 1, 4, 4),
    (1024, 1, 64, 1),
    (1024, 1, 64, 2),
    (1024, 1, 64, 4),
    (1024, 1, 128, 1),
    (1024, 1, 128, 2),
    (1024, 1, 128, 4),
    # sl=1024, hs=1536
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 4, 4),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 1),
    # sl=16384, hs=768
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
    # sl=16384, hs=1
    (16384, 1, 4, 1),
    (16384, 1, 4, 2),
    (16384, 1, 4, 4),
    (16384, 1, 64, 1),
    (16384, 1, 64, 2),
    (16384, 1, 64, 4),
    (16384, 1, 128, 1),
    (16384, 1, 128, 2),
    (16384, 1, 128, 4),
    (16384, 1, 128, 2),
    (16384, 1, 128, 2),
]
|
|
|
|
|
def _to_numpy(x: torch.Tensor) -> np.ndarray: |
|
return x.detach().cpu().numpy() |
|
|
|
|
|
@pytest.mark.gpu
@pytest.mark.parametrize((
    'sl',
    'hs',
    'ne',
    'top_k',
), PADDED_SCATTER_TESTS)
def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int):
    """Check ``ops.padded_scatter`` against a NumPy reference on random inputs.

    Args:
        sl: Number of rows of the random input matrix.
        hs: Number of columns of the random input matrix.
        ne: Number of experts; expert ids are drawn from ``[0, ne)``.
        top_k: Number of expert assignments drawn per input row
            (``sl * top_k`` assignments in total).
    """
    # Random half-precision input on the GPU.
    x = torch.randn((sl, hs), requires_grad=True).cuda().half()

    # Draw a random expert id for every one of the sl * top_k assignments and
    # derive the routing metadata from it.
    # NOTE(review): the exact semantics of these ops are defined in megablocks;
    # the names suggest a sort permutation, per-expert counts, counts rounded
    # up to a multiple of 128, and their running totals -- confirm there.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # Random half-precision scale factors, one per assignment.
    weights = torch.rand((sl * top_k,), requires_grad=True).cuda().half()

    # Gather first so that the scatter below receives the padded layout as
    # its input.
    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

    def padded_scatter(
        x: torch.Tensor,
        indices: torch.Tensor,
        bin_ids: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        padded_bins: torch.Tensor,
        top_k: int,
    ):
        """Reference scatter computed with NumPy on the host.

        Walks the bins in order. For bin ``i`` it reads consecutive rows of
        ``x`` starting at ``padded_bins[i - 1]`` (``0`` for the first bin) and
        consumes entries of ``indices`` until position ``bins[i]`` is reached.
        Each row read is scaled by ``weights[indices[j]]`` and accumulated
        into output row ``indices[j] // top_k``.

        ``bin_ids`` is accepted so the signature mirrors
        ``ops.padded_scatter``; the reference does not need it.

        Returns:
            A half-precision CUDA tensor with ``len(indices) // top_k`` rows
            and ``hs`` columns.
        """
        x_np = _to_numpy(x)
        indices_np = _to_numpy(indices)
        weights_np = _to_numpy(weights)
        bins_np = _to_numpy(bins)
        padded_bins_np = _to_numpy(padded_bins)

        # np.zeros defaults to float64, so the accumulation happens in higher
        # precision than the half-precision inputs.
        out = np.zeros((indices_np.shape[0] // top_k, hs))
        out_idx = 0
        for i, end in enumerate(bins_np):
            # Input rows of a bin start at the previous padded boundary, while
            # out_idx keeps running across the unpadded boundaries.
            in_idx = 0 if i == 0 else padded_bins_np[i - 1]
            while out_idx < end:
                store_idx = indices_np[out_idx]
                scale = weights_np[store_idx]
                out[store_idx // top_k, :] += scale * x_np[in_idx, :]
                out_idx += 1
                in_idx += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.padded_scatter(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )
    expected_out = padded_scatter(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )

    # Sanity check: the backward pass must run without raising. Gradient
    # values are not verified here.
    out.backward(torch.randn_like(out))

    # Approximate comparison: the reference accumulates in float64 and is only
    # then cast to half, so it need not match the op bit-for-bit.
    # assert_allclose raises on mismatch by itself; it is called directly so
    # the check still runs when Python strips `assert` statements (-O).
    np.testing.assert_allclose(
        _to_numpy(out),
        _to_numpy(expected_out),
        rtol=5e-3,
    )
|
|