import pytest
import torch

import megablocks

# These tests exercise CUDA kernels; skip the whole module when no GPU is available.
pytestmark = pytest.mark.skipif(not torch.cuda.is_available(), reason="requires CUDA")


def test_import():
    """Simple test to check if the module can be imported."""
    print("megablocks module imported successfully.")
    print("Available functions:", dir(megablocks))
|
    expected_functions = [
        "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP",
        "SparseGLU", "SparseMLP", "argsort",
        "backend", "cumsum", "dMoE", "exclusive_cumsum",
        "get_load_balancing_loss", "grouped_gemm_util", "histogram",
        "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward",
        "replicate_forward", "sort", "torch",
    ]

    for func in expected_functions:
        assert func in dir(megablocks), f"Missing function: {func}"
|
def test_exclusive_cumsum():
    """Test exclusive cumulative sum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.exclusive_cumsum(x, 0, out)
    # out inherits x's int16 dtype, so the reference tensor must match it;
    # torch.equal treats tensors of different dtypes as unequal.
    expected = torch.tensor([0, 1, 3, 6], dtype=torch.int16).cuda()
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
    print("exclusive cumsum output:", out)
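
# A minimal cross-check, assuming exclusive_cumsum follows the usual definition
# (an inclusive cumsum shifted right, starting at zero). The test name and the
# pure-PyTorch reference below are illustrative additions, not megablocks API.
def test_exclusive_cumsum_matches_reference():
    """Compare the kernel against a reference built from torch.cumsum."""
    x = torch.arange(1, 9, dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.exclusive_cumsum(x, 0, out)
    # Inclusive cumsum minus the element itself yields the exclusive cumsum.
    reference = (torch.cumsum(x, dim=0) - x).to(out.dtype)
    assert torch.equal(out, reference), f"Expected {reference}, got {out}"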
|
def test_inclusive_cumsum():
    """Test inclusive cumulative sum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    # Call positionally, matching the exclusive_cumsum test above.
    megablocks.inclusive_cumsum(x, 0, out)
    # Match out's int16 dtype; torch.equal fails on a dtype mismatch.
    expected = torch.tensor([1, 3, 6, 10], dtype=torch.int16).cuda()
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
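
# A minimal cross-check against torch.cumsum, assuming the kernel implements a
# plain inclusive scan. The test name and reference are illustrative additions,
# not megablocks API.
def test_inclusive_cumsum_matches_reference():
    """Compare the kernel against torch.cumsum."""
    x = torch.arange(1, 9, dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.inclusive_cumsum(x, 0, out)
    reference = torch.cumsum(x, dim=0).to(out.dtype)
    assert torch.equal(out, reference), f"Expected {reference}, got {out}"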
|
def test_histogram():
    """Test histogram operation."""
    x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
    num_bins = 3
    hist = megablocks.histogram(x, num_bins)
    expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda()
    assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}"
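
# A minimal cross-check against torch.bincount, assuming megablocks.histogram
# counts occurrences of values in [0, num_bins). The test name and the
# reference below are illustrative additions, not megablocks API; counts are
# compared as int64 to sidestep any dtype difference between the two ops.
def test_histogram_matches_bincount():
    """Compare the kernel against torch.bincount on random in-range values."""
    num_bins = 8
    x = torch.randint(0, num_bins, (256,), dtype=torch.int16).cuda()
    hist = megablocks.histogram(x, num_bins)
    reference = torch.bincount(x.long(), minlength=num_bins)
    assert torch.equal(hist.long(), reference), f"Expected {reference}, got {hist}"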