kernel
megablocks / tests /ops /replicate_test.py
drbh
feat: validate build with original test suite
9c4ca75
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
try:
from megablocks._ops import ops as backend # type: ignore
except ModuleNotFoundError as e:
raise ModuleNotFoundError("No module named 'megablocks_ops'.") from e
from megablocks import ops
def promote_scalar(x: torch.Tensor) -> torch.Tensor:
    """Give a 0-dimensional tensor a single dimension; pass others through.

    Args:
        x: Tensor of any rank.

    Returns:
        ``x.view(1)`` when ``x`` has no dimensions, otherwise ``x`` itself.
    """
    if x.dim() == 0:
        return x.view(1)
    return x
# (tokens, num_centers, top_k) cases used to parametrize the replicate tests.
# Ordering: the lone single-center case first, then, for each token count,
# top_k runs over 1, 2, 4, 8 and is paired with every num_centers value.
REPLICATE_TESTS = [
    (8, 1, 1),
    *((8, centers, top_k) for top_k in (1, 2, 4, 8) for centers in (2, 4, 8)),
    *(
        (16384, centers, top_k)
        for top_k in (1, 2, 4, 8)
        for centers in (2, 4, 8, 16, 32, 64, 128)
    ),
]
@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate(tokens: int, num_centers: int, top_k: int):
    """Check ``ops.replicate`` for exact equality with a NumPy reference.

    Args:
        tokens: Number of tokens to assign; also the output width requested.
        num_centers: Number of centers tokens are randomly assigned to.
        top_k: Number of rows in the random weight matrix.
    """
    # Random center assignment per token, reduced to inclusive bin ends.
    assignment = torch.randint(0, num_centers, (tokens,)).cuda().int()
    counts = ops.histogram(assignment, num_centers)
    bins = promote_scalar(ops.inclusive_cumsum(counts, 0))
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    def replicate(x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        """Reference: copy column ``i`` of each row into slots below ``bins[i]``."""
        weights = x.cpu().numpy()
        bin_ends = bins.cpu().numpy()
        expanded = np.zeros((weights.shape[0], num_outputs))
        for row, row_weights in enumerate(weights):
            # The cursor persists across bins, so each bin starts where the
            # previous one ended.
            cursor = 0
            for col, stop in enumerate(bin_ends):
                fill = row_weights[col]
                while cursor < stop:
                    expanded[row, cursor] = fill
                    cursor += 1
        return torch.from_numpy(expanded).cuda().half()

    out = ops.replicate(center_weights, bins, tokens)
    expected_out = replicate(center_weights, bins, tokens)
    assert (out == expected_out).all()
@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
    """Check ``backend.replicate_backward`` against weights scaled by bin size.

    The forward ``ops.replicate`` output is fed back in as the gradient, and
    the result is compared (within a relative tolerance) to each weight
    multiplied by the number of tokens assigned to its center.

    Args:
        tokens: Number of tokens to assign; also the width of the gradient.
        num_centers: Number of centers tokens are randomly assigned to.
        top_k: Number of rows in the random weight matrix.
    """
    assignment = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(assignment, num_centers)
    bins = promote_scalar(ops.inclusive_cumsum(tokens_per_center, 0))
    center_weights = torch.randn(top_k, num_centers).cuda().half()

    # The forward output doubles as the incoming gradient; the backend call
    # is expected to fill the preallocated `reduced` buffer.
    grad = ops.replicate(center_weights, bins, tokens)
    reduced = torch.empty_like(center_weights)
    backend.replicate_backward(grad, bins, reduced)

    per_center = tokens_per_center.view(1, num_centers)
    expected_out = center_weights * per_center

    # NOTE: This floating-point reduction could be a problem for training
    # stability and accuracy, hence the loose relative tolerance here.
    rel_tol = 1e-2
    assert torch.allclose(reduced, expected_out, rtol=rel_tol)