import torch
import megablocks


def test_import():
    """Simple test to check if the module can be imported."""
    print("megablocks_moe module imported successfully.")
    print("Available functions:", dir(megablocks))
    expected_functions = [
        "Arguments", "MLP", "MoE", "ParallelDroplessMLP", "ParallelMLP",
        "SparseGLU", "SparseMLP", "argsort",
        "backend", "cumsum", "dMoE", "exclusive_cumsum",
        "get_load_balancing_loss", "grouped_gemm_util", "histogram",
        "inclusive_cumsum", "indices", "layers", "ops", "replicate_backward",
        "replicate_forward", "sort", "torch",
    ]
    # Check if all expected functions are available.
    for func in expected_functions:
        assert func in dir(megablocks), f"Missing function: {func}"
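

# The ops exercised below need a GPU. A minimal skip guard, assuming pytest is
# the test runner (suggested by the test_* naming); `requires_cuda` is a helper
# introduced here, not part of megablocks. It could be applied as
# @requires_cuda on each CUDA-dependent test.
import pytest

requires_cuda = pytest.mark.skipif(
    not torch.cuda.is_available(), reason="megablocks ops require CUDA"
)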


# exclusive_cumsum
def test_exclusive_cumsum():
    """Test exclusive cumulative sum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.exclusive_cumsum(x, 0, out)
    # `out` inherits x's dtype (int16), so the expected tensor must use the
    # same dtype for torch.equal to succeed.
    expected = torch.tensor([0, 1, 3, 6], dtype=torch.int16).cuda()
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
    print("cumsum output:", out)


# inclusive_cumsum
def test_inclusive_cumsum():
    """Test inclusive cumulative sum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.inclusive_cumsum(x, dim=0, out=out)
    # Match `out`'s dtype (int16) so torch.equal can succeed.
    expected = torch.tensor([1, 3, 6, 10], dtype=torch.int16).cuda()
    assert torch.equal(out, expected), f"Expected {expected}, got {out}"
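

# A hedged cross-check (again, not in the original file): inclusive cumsum
# should agree directly with torch.cumsum.
def test_inclusive_cumsum_matches_torch_reference():
    """Compare megablocks.inclusive_cumsum against torch.cumsum."""
    x = torch.tensor([1, 2, 3, 4], dtype=torch.int16).cuda()
    out = torch.empty_like(x)
    megablocks.inclusive_cumsum(x, dim=0, out=out)
    reference = torch.cumsum(x, dim=0).to(x.dtype)  # cumsum promotes to int64
    assert torch.equal(out, reference), f"Expected {reference}, got {out}"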


# histogram
def test_histogram():
    """Test histogram operation."""
    x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
    num_bins = 3
    hist = megablocks.histogram(x, num_bins)
    expected_hist = torch.tensor([1, 2, 3], dtype=torch.int32).cuda()
    assert torch.equal(hist, expected_hist), f"Expected {expected_hist}, got {hist}"
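

# A hedged cross-check against torch.bincount, assuming megablocks.histogram
# counts occurrences of the values 0..num_bins-1 (as the test above implies).
def test_histogram_matches_bincount():
    """Compare megablocks.histogram against a torch.bincount reference."""
    x = torch.tensor([0, 1, 1, 2, 2, 2], dtype=torch.int16).cuda()
    num_bins = 3
    hist = megablocks.histogram(x, num_bins)
    # bincount wants an int64 input; cast its result back to hist's dtype.
    reference = torch.bincount(x.long(), minlength=num_bins).to(hist.dtype)
    assert torch.equal(hist, reference), f"Expected {reference}, got {hist}"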