# megablocks/tests/ops_test.py
import torch
import megablocks
import unittest
from absl.testing import parameterized
# import itertools
# import numpy as np
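

# Loose element-wise comparison for bfloat16 results: the check passes as long
# as no more than `pct` percent of the values fall outside torch.isclose's
# tolerance.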
def allclose(x, y, pct=2.0):
    mask = torch.isclose(x, y, rtol=1e-5)
    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
    if pct_diff > pct:
        print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
        print("{:.2f}% of values not close.".format(pct_diff))
        return False
    return True
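

# Expand each (z, m, k, n) problem with a trans_b flag. The second flag,
# batch_sizes_on_device, is pinned to False until the TODO below is resolved.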
def add_flags(x):
    out = []
    for y in x:
        for trans_b in (False, True):
            out.append(y + (trans_b, False))
            # TODO: Revisit enabling batch_sizes_on_device
            # for batch_sizes_on_device in (False, True):
            #     out.append(y + (trans_b, batch_sizes_on_device))
    return out
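

# (z, m, k, n) problem sizes: z groups of m rows each, with inner dimension k
# and output dimension n.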
_TEST_PROBLEMS = add_flags((
    (1, 128, 128, 128),
    (8, 128, 128, 128),
    (16, 128, 128, 128),
    (1, 128, 256, 512),
    (8, 128, 256, 512),
    (16, 128, 256, 512),
))
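

# Small random bfloat16 inputs on the GPU, centered around zero and scaled by
# 1 / (x * y) so the grouped matmuls stay well within bfloat16 range.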
def randn(bs, x, y):
    out = (torch.rand(bs, x, y) - 0.5) * 2 / (y * x)
    return out.cuda().to(torch.bfloat16)
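

# Pure-PyTorch reference: walk the groups, slice `size` rows out of `a` for
# each one, and multiply by the matching matrix in `b` (transposed when
# trans_b is set).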
def gmm(a, b, batch_sizes, trans_b=False):
    batch_sizes = batch_sizes.cpu().numpy()
    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start:start + size, :] @ rhs)
        start += size
    return torch.cat(out)
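

# Each test case is a (z, m, k, n, trans_b, batch_sizes_on_device) tuple from
# _TEST_PROBLEMS. The kernel grouped GEMM (megablocks.gg_ops.gmm) is checked
# against the PyTorch reference above for both the forward output and the
# input gradients.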
@parameterized.parameters(*_TEST_PROBLEMS)
class OpsTest(parameterized.TestCase):

    def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        batch_sizes = torch.tensor([m] * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        # out = ops.gmm(a, b, batch_sizes, trans_b)
        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        # print("out", out)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))
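
    # Same check as above, but the rows are split across the groups according
    # to a random distribution instead of fixed equal-sized batches.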
    def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)

        dist = torch.rand(z)
        dist /= dist.sum()
        batch_sizes = (dist * m).to(torch.long)
        error = m * z - batch_sizes.sum()
        batch_sizes[-1] += error
        assert batch_sizes.sum() == (m * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        # TODO: Review to ensure that the gradients are correct.
        # self.assertTrue(allclose(b.grad, b_ref.grad))
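

# The edge-case tests currently run with batch sizes on the host only; the
# commented-out parametrization would also cover the on-device variant (see
# the TODO in add_flags above).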
# @parameterized.parameters(False, True)
@parameterized.parameters(False, False)
class EdgeCasesTest(unittest.TestCase):
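
    # Large MoE-shaped problem in which one expert receives zero tokens; the
    # empty group must not break the forward or backward pass.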
    def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
        torch.manual_seed(0)
        m = 16384
        k = 4096
        n = 14336
        num_experts = 8

        a = randn(num_experts, m // num_experts, k).view(-1, k)
        b = randn(num_experts, k, n)
        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes)
        expected_out = gmm(a_ref, b_ref, batch_sizes)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))
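
    # With trans_a=True the reduction dimension of each per-group product is
    # that group's batch size, so groups with zero tokens (k == 0) must write
    # zeros into the preallocated output `c`. Because the inputs are all ones,
    # every entry of c[i] should equal the i-th batch size.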
    def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
        sz = 128
        total_tokens = 192

        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
        self.assertTrue((c[0] == 0).all())
        self.assertTrue((c[1] == 128).all())
        self.assertTrue((c[2] == 0).all())
        self.assertTrue((c[3] == 64).all())


if __name__ == '__main__':
    unittest.main()