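"""Numerical test for the vendored grouped GEMM (gg) ops in megablocks.

Checks megablocks.gg_ops.gmm against a host-side reference loop, comparing
both forward outputs and gradients.
"""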
import torch
import megablocks


def randn(bs, x, y):
    # Small uniform values in [-1, 1] / (y * x), cast to bf16 on the GPU,
    # so accumulation error stays well inside the test tolerances.
    out = (torch.rand(bs, x, y) - 0.5) * 2 / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    """Reference grouped GEMM: one host-side matmul per group."""
    batch_sizes = batch_sizes.cpu().numpy()
    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)
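
# Shape sketch (illustration only, not used by the test): with
# batch_sizes = [2, 3], `a` is (5, k) and `b` is (2, k, n); group 0
# computes a[0:2] @ b[0], group 1 computes a[2:5] @ b[1], and the two
# results are concatenated into a (5, n) output.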


def test_gmm():
    z = 1  # number of groups
    m = 128
    n = 128
    k = 128
    trans_b = False
    batch_sizes_on_device = False
    # TODO: fix to enable batch_sizes_on_device
    # batch_sizes_on_device = True

    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    # With trans_b, each group's weight is stored as (n, k) and transposed
    # inside the matmul; otherwise it is stored as (k, n).
    b = randn(z, n, k) if trans_b else randn(z, k, n)
    batch_sizes = torch.tensor([m] * z)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    a.requires_grad_(True)
    b.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    # Forward: the vendored kernel vs. the host-side reference.
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"

    # Backward: gradients w.r.t. both operands must match the reference.
    out.sum().backward()
    expected_out.sum().backward()
    assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
    assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
    print("Test passed successfully!")