|
|
|
|
|
|
|
import numpy as np |
|
import pytest |
|
import torch |
|
|
|
from megablocks import ops |
|
|
|
# Parametrization grid for testBinnedScatter, as (sl, hs, ne, top_k) tuples.
# Each (sl, hs) shape is paired with its expert counts, and every pairing is
# run with top_k in {1, 2, 4}. The generated order is shape-major, then
# expert count, then top_k.
_BINNED_SCATTER_TESTS = tuple(
    (sl, hs, ne, top_k)
    for sl, hs, expert_counts in (
        (4, 2, (2,)),
        (1024, 1536, (4, 64, 128)),
        (16384, 768, (4, 64, 128)),
    )
    for ne in expert_counts
    for top_k in (1, 2, 4)
)
|
|
|
|
|
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), _BINNED_SCATTER_TESTS)
def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int):
    """Check ops.binned_scatter against a NumPy reference implementation.

    Random half-precision data is first passed through ops.binned_gather,
    then scattered back with ops.binned_scatter. The result is compared
    with a straightforward CPU loop that performs the same weighted
    accumulation.

    Args:
        sl: Number of rows in the input (and in the scattered output).
        hs: Number of columns in the input (and in the scattered output).
        ne: Number of bins; top_expert values are drawn from [0, ne).
        top_k: Number of assignments drawn per input row.
    """
    # Cap on how many entries are consumed per bin. It is passed to
    # ops.binned_gather and bounds the inner loop of the reference below.
    ec = (sl * top_k) // ne

    # Random input data, moved to the GPU as half precision.
    x = torch.randn((sl, hs)).cuda().half()

    # Draw sl * top_k random bin assignments in [0, ne).
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    # Presumably the permutation that sorts top_expert -- confirm with ops.sort.
    _, indices = ops.sort(top_expert)
    # Running totals of the per-bin counts; bins[i] is used below as the
    # end offset of bin i within `indices`.
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

    # One random scale factor in [0, 1) per assignment.
    weights = torch.rand((sl * top_k,)).cuda().half()

    # The gathered tensor is indexed as x[bin, position, :] by the reference.
    x = ops.binned_gather(x, indices, bins, ec, top_k)

    def binned_scatter(
        x: torch.Tensor,
        indices: torch.Tensor,
        weights: torch.Tensor,
        bins: torch.Tensor,
        top_k: int,
    ):
        """CPU reference: weighted scatter-add of binned rows into (sl, hs)."""
        x = x.cpu().numpy()
        indices = indices.cpu().numpy()
        weights = weights.cpu().numpy()
        bins = bins.cpu().numpy()

        # np.zeros defaults to float64, so the reference accumulates at
        # higher precision than the half-precision inputs.
        out = np.zeros((sl, hs))
        start = 0
        for i in range(ne):
            end = bins[i]
            # Only the first `ec` entries of each bin are consumed.
            for j in range(min(ec, end - start)):
                slot = indices[start + j]
                # Weights are looked up by the raw index; the destination
                # row is that index integer-divided by top_k.
                out[slot // top_k, :] += weights[slot] * x[i, j, :]
            start = end
        return torch.from_numpy(out).cuda().half()

    out = ops.binned_scatter(x, indices, weights, bins, top_k)
    expected_out = binned_scatter(x, indices, weights, bins, top_k)

    # Compare approximately: the reference accumulates in float64 before
    # rounding to half, so bit-exact agreement with the kernel may not hold.
    # assert_allclose raises on mismatch by itself; calling it directly
    # (rather than inside `assert ... is None`) keeps the check alive even
    # when Python runs with assertions disabled (-O).
    np.testing.assert_allclose(
        out.cpu(),
        expected_out.cpu(),
        rtol=5e-3,
    )
|
|