|
|
|
|
|
|
|
import numpy as np |
|
import pytest |
|
import torch |
|
|
|
from megablocks import ops |
|
|
|
# Parameter triples for `test_topology`, in (sl, hs, ne) order.
# Two sweeps over power-of-two `ne` values (2..512 and 2..1024) followed by
# one extra case with a small `sl` and a large `hs`.
TOPOLOGY_TESTS = (
    *((1024, 1536, 2**power) for power in range(1, 10)),
    *((16384, 768, 2**power) for power in range(1, 11)),
    (8, 14336, 8),
)
|
|
|
|
|
@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
def test_topology(sl: int, hs: int, ne: int):
    """Compare ``ops.topology`` with a plain NumPy reference implementation.

    Args:
        sl: Number of random expert assignments to draw.
        hs: Width that is split into ``blocking``-sized block columns; must be
            a multiple of the block size.
        ne: Exclusive upper bound for the random expert ids, also passed to
            ``ops.histogram`` as the bin count.
    """
    blocking = 128
    assert hs % blocking == 0

    # Draw one random expert id in [0, ne) per entry, then derive the padded
    # cumulative bin boundaries via the megablocks ops (histogram -> round up
    # to the block size -> inclusive cumulative sum).
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

    # Output dimensions, measured in blocks.
    output_block_rows = int(padded_bins[-1]) // blocking
    output_block_columns = hs // blocking

    def topology(
        padded_bins: torch.Tensor,
        blocking: int,
        rows: int,
        columns: int,
    ) -> torch.Tensor:
        """Reference result: for each block row, the column ids of its bin.

        Block rows are assigned to bins in order; every block row that falls
        in bin ``i`` is filled with ``i * columns + j`` for ``j`` in
        ``range(columns)``.
        """
        bin_ends = padded_bins.cpu().numpy()

        # Build the indices in an integer dtype instead of float64 so no
        # float round-trip is involved before the final int16 cast.
        expected = np.zeros(rows * columns, dtype=np.int64)
        column_ids = np.arange(columns, dtype=np.int64)

        row = 0
        for bin_idx, bin_end in enumerate(bin_ends):
            # Bin boundaries are in elements; convert to block rows.
            end_row = bin_end // blocking
            while row < end_row:
                first = row * columns
                expected[first:first + columns] = column_ids + bin_idx * columns
                row += 1
        return torch.from_numpy(expected).cuda().short()

    out = ops.topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    expected_out = topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    assert torch.all(torch.eq(out, expected_out))
|
|