# megablocks/tests/test_mb_moe.py
# feat: validate build with original test suite (commit 9c4ca75)

import torch
import megablocks
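
# Assumption: this suite runs under pytest on a CUDA-capable machine. A minimal
# module-level guard (standard pytest idiom) skips everything when CUDA is absent.
import pytest

pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA device required")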


def test_import():
    """Check that the module imports and exposes the expected public names."""
    print("megablocks module imported successfully.")
    print("Available functions:", dir(megablocks))
    expected_functions = [
        "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP",
        "SparseGLU", "SparseMLP", "argsort",
        "backend", "cumsum", "dMoE", "exclusive_cumsum",
        "get_load_balancing_loss", "grouped_gemm_util", "histogram",
        "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward",
        "replicate_forward", "sort", "torch",
    ]
    # Check that every expected name is exported by the module.
    for func in expected_functions:
        assert func in dir(megablocks), f"Missing function: {func}"


# exclusive_cumsum
def test_exclusive_cumsum():
    """Test exclusive cumulative sum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.exclusive_cumsum(x, 0, out)
    # The kernel writes into `out`, which shares x's dtype, so the reference
    # must be int16 as well: torch.equal returns False on a dtype mismatch.
    expected = torch.tensor([0, 1, 3, 6], dtype=torch.int16).cuda()
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
    print("exclusive cumsum output:", out)


# inclusive_cumsum
def test_inclusive_cumsum():
    """Test inclusive cumulative sum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    # Positional args, matching the exclusive_cumsum call above; bound kernel
    # functions often reject keyword arguments.
    megablocks.inclusive_cumsum(x, 0, out)
    # As above, the reference must match out's dtype (int16), not float32.
    expected = torch.tensor([1, 3, 6, 10], dtype=torch.int16).cuda()
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
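

# A hedged cross-check (not in the original suite): it assumes the kernel matches
# torch.cumsum's inclusive-scan semantics on 1-D int16 input.
def test_inclusive_cumsum_matches_torch():
    x = torch.randint(0, 10, (256,)).to(torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.inclusive_cumsum(x, 0, out)
    ref = torch.cumsum(x, dim=0).to(out.dtype)
    assert torch.equal(out, ref), f"Expected {ref}, got {out}"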


# histogram
def test_histogram():
    """Test histogram operation."""
    x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
    num_bins = 3
    hist = megablocks.histogram(x, num_bins)
    # Counts per bin: one 0, two 1s, three 2s.
    expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda()
    assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}"
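

# A hedged cross-check (not in the original suite): it assumes megablocks.histogram(x, n)
# counts occurrences of each value in [0, n), like torch.bincount with minlength=n.
def test_histogram_matches_bincount():
    num_bins = 8
    x = torch.randint(0, num_bins, (512,)).to(torch.int16).cuda()
    hist = megablocks.histogram(x, num_bins)
    ref = torch.bincount(x.int(), minlength=num_bins).to(hist.dtype)
    assert torch.equal(hist, ref), f"Expected {ref}, got {hist}"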