kernel
megablocks / tests /ops /topology_test.py
drbh
feat: validate build with original test suite
9c4ca75
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from megablocks import ops
# Parametrization for test_topology as (sl, hs, ne) triples: two sweeps over
# power-of-two expert counts at fixed (sl, hs), plus a single small-sl /
# wide-hs case.
TOPOLOGY_TESTS = (
    # sl=1024, hs=1536, ne = 2, 4, ..., 512
    *((1024, 1536, 2**exponent) for exponent in range(1, 10)),
    # sl=16384, hs=768, ne = 2, 4, ..., 1024
    *((16384, 768, 2**exponent) for exponent in range(1, 11)),
    (8, 14336, 8),
)
def _reference_topology(
    padded_bins: torch.Tensor,
    blocking: int,
    rows: int,
    columns: int,
) -> torch.Tensor:
    """Compute the expected ``ops.topology`` output on the host.

    Expert ``i`` owns the block rows from the previous bin's end up to
    ``padded_bins[i] // blocking`` (exclusive). Each of those rows is filled
    with ``i * columns + j`` for ``j in range(columns)``. Rows past the last
    bin's end stay zero.

    Args:
        padded_bins: Inclusive cumulative sum of the per-expert token counts,
            each rounded up to a multiple of ``blocking`` (see test_topology).
        blocking: Block size (a plain int).
        rows: Number of output block rows.
        columns: Number of output block columns.

    Returns:
        A flat int16 CUDA tensor of length ``rows * columns``.
    """
    bins = padded_bins.cpu().numpy()
    # Build directly in int16 (the dtype being compared) instead of filling a
    # float64 buffer element by element and casting at the end.
    out = np.zeros((rows, columns), dtype=np.int16)
    column_offsets = np.arange(columns)
    start = 0
    for expert, end in enumerate(bins // blocking):
        # A bin that does not advance past `start` owns no rows.
        if end > start:
            out[start:end] = expert * columns + column_offsets
            start = end
    return torch.from_numpy(out.reshape(-1)).cuda()


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne'), TOPOLOGY_TESTS)
def test_topology(sl: int, hs: int, ne: int):
    """Check ``ops.topology`` against a host-side reference implementation.

    Args:
        sl: Number of tokens; each gets a random expert id in ``[0, ne)``.
        hs: Feature width; split into ``hs // 128`` column blocks.
        ne: Number of experts.
    """
    blocking = 128
    assert hs % blocking == 0

    # Randomly assign tokens to experts, then pad each expert's token count
    # up to a whole number of blocks and take the running total.
    top_expert = torch.randint(0, ne, (sl,)).cuda().int()
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, blocking)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)

    # Output dimensions, measured in blocks.
    output_block_rows = int(padded_bins[-1]) // blocking
    output_block_columns = hs // blocking

    out = ops.topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    expected_out = _reference_topology(
        padded_bins,
        blocking,
        output_block_rows,
        output_block_columns,
    )
    # torch.eq broadcasts, so a wrongly-shaped result could still compare
    # equal element-wise; pin the shape before comparing values.
    assert out.shape == expected_out.shape
    assert torch.all(torch.eq(out, expected_out))