# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from megablocks import ops
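
# Each case is (sl, hs, ne, top_k): sequence length, hidden size,
# number of experts, and expert assignments per token.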
PADDED_SCATTER_TESTS = [
(4, 2, 2, 2),
(4, 2, 2, 1),
(1024, 1, 4, 1),
(1024, 1, 4, 2),
(1024, 1, 4, 4),
(1024, 1, 64, 1),
(1024, 1, 64, 2),
(1024, 1, 64, 4),
(1024, 1, 128, 1),
(1024, 1, 128, 2),
(1024, 1, 128, 4),
(1024, 1536, 4, 1),
(1024, 1536, 4, 2),
(1024, 1536, 4, 4),
(1024, 1536, 64, 1),
(1024, 1536, 64, 2),
(1024, 1536, 64, 4),
(1024, 1536, 128, 1),
(1024, 1536, 128, 2),
(1024, 1536, 128, 4),
(16384, 768, 4, 1),
(16384, 768, 4, 2),
(16384, 768, 4, 4),
(16384, 768, 64, 1),
(16384, 768, 64, 2),
(16384, 768, 64, 4),
(16384, 768, 128, 1),
(16384, 768, 128, 2),
(16384, 768, 128, 4),
(16384, 1, 4, 1),
(16384, 1, 4, 2),
(16384, 1, 4, 4),
(16384, 1, 64, 1),
(16384, 1, 64, 2),
(16384, 1, 64, 4),
(16384, 1, 128, 1),
(16384, 1, 128, 2),
(16384, 1, 128, 4),
]
def _to_numpy(x: torch.Tensor) -> np.ndarray:
return x.detach().cpu().numpy()
@pytest.mark.gpu
@pytest.mark.parametrize((
'sl',
'hs',
'ne',
'top_k',
), PADDED_SCATTER_TESTS)
def test_padded_scatter(sl: int, hs: int, ne: int, top_k: int):
# Create the data and indices.
    x = torch.randn((sl, hs), device='cuda', dtype=torch.half, requires_grad=True)
# Randomly assign tokens to experts.
top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
bin_ids, indices = ops.sort(top_expert)
tokens_per_expert = ops.histogram(top_expert, ne)
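    # Pad each expert's token count up to a multiple of 128 so every bin is
    # block-aligned for the padded kernels; `padded_bins` and `bins` hold the
    # cumulative end offsets of the padded and unpadded bins respectively.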
padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
bins = ops.inclusive_cumsum(tokens_per_expert, 0)
# Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,), device='cuda', dtype=torch.half, requires_grad=True)
# Gather the data to prepare for backwards.
x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)
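
    # Pure-numpy reference: walk each expert's bin in the padded layout and
    # scatter-reduce rows back to their source tokens, scaled by the
    # per-assignment weight.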
def padded_scatter(
x: torch.Tensor,
indices: torch.Tensor,
bin_ids: torch.Tensor,
weights: torch.Tensor,
bins: torch.Tensor,
padded_bins: torch.Tensor,
top_k: int,
):
        x: np.ndarray = _to_numpy(x)
indices: np.ndarray = _to_numpy(indices)
bin_ids: np.ndarray = _to_numpy(bin_ids)
weights: np.ndarray = _to_numpy(weights)
bins: np.ndarray = _to_numpy(bins)
padded_bins: np.ndarray = _to_numpy(padded_bins)
out = np.zeros((indices.shape[0] // top_k, hs))
out_idx = 0
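        # `out_idx` walks unpadded (bins) coordinates while `in_idx` walks
        # padded (padded_bins) coordinates, skipping the alignment padding
        # between consecutive expert bins.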
for i in range(len(bins)):
in_idx = 0 if i == 0 else padded_bins[i - 1]
end = bins[i]
while out_idx < end:
store_idx = indices[out_idx]
scale = weights[store_idx]
store_idx //= top_k
out[store_idx, :] += scale * x[in_idx, :]
out_idx += 1
in_idx += 1
return torch.from_numpy(out).cuda().half()
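
    # Run the op under test and the numpy reference on identical inputs.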
out = ops.padded_scatter(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
top_k,
)
expected_out = padded_scatter(
x,
indices,
bin_ids,
weights,
bins,
padded_bins,
top_k,
)
out.backward(torch.randn_like(out)) # sanity check backward pass
    # NOTE: Check approximate equality because the scatter reduce uses atomics,
    # so the accumulation order (and hence fp16 rounding) can vary between runs.
    # np.testing.assert_allclose raises an AssertionError on mismatch.
    np.testing.assert_allclose(
        _to_numpy(out),
        _to_numpy(expected_out),
        rtol=5e-3,
    )