import unittest

import torch
from absl.testing import parameterized

import megablocks

def allclose(x, y, pct=2.0):
    """Returns True if at most `pct` percent of elements differ beyond rtol."""
    mask = torch.isclose(x, y, rtol=1e-5)
    pct_diff = (mask.numel() - mask.sum()) / mask.numel() * 100
    if pct_diff > pct:
        print(x[torch.logical_not(mask)], y[torch.logical_not(mask)])
        print("{:.2f}% of values not close.".format(pct_diff))
        return False
    return True

def add_flags(x):
    """Expands each (z, m, k, n) problem with (trans_b, batch_sizes_on_device) flags."""
    out = []
    for y in x:
        for trans_b in (False, True):
            for batch_sizes_on_device in (False, True):
                out.append(y + (trans_b, batch_sizes_on_device))
    return out

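# Illustrative sketch only (not used by the tests): with both flag loops, each
# problem tuple expands into four parameterizations, e.g.
#   add_flags(((1, 128, 128, 128),)) ->
#     [(1, 128, 128, 128, False, False),
#      (1, 128, 128, 128, False, True),
#      (1, 128, 128, 128, True, False),
#      (1, 128, 128, 128, True, True)]
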
_TEST_PROBLEMS = add_flags((
    # (z, m, k, n): number of groups, rows per group, inner dim, output dim.
    (1, 128, 128, 128),
    (8, 128, 128, 128),
    (16, 128, 128, 128),
    (1, 128, 256, 512),
    (8, 128, 256, 512),
    (16, 128, 256, 512),
))

def randn(bs, x, y):
    # Center values in [-1, 1) and scale by 1 / (x * y) to keep bfloat16
    # accumulation error small.
    out = (torch.rand(bs, x, y) - 0.5) * 2 / (y * x)
    return out.cuda().to(torch.bfloat16)

def gmm(a, b, batch_sizes, trans_b=False):
    """Reference (eager) grouped GEMM used to check the fused kernel."""
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start:start + size, :] @ rhs)
        start += size
    return torch.cat(out)

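# Illustrative sketch only: for batch_sizes = [2, 3], a of shape (5, k) and
# b of shape (2, k, n), the grouped GEMM computes
#   out[0:2] = a[0:2] @ b[0]
#   out[2:5] = a[2:5] @ b[1]
# i.e. each group's rows of `a` are multiplied by that group's matrix in `b`.
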
@parameterized.parameters(*_TEST_PROBLEMS)
class OpsTest(parameterized.TestCase):

    def testGroupedGemm_FixedSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)
        batch_sizes = torch.tensor([m] * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients against the eager reference.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

    def testGroupedGemm_VariableSizes(self, z, m, k, n, trans_b, batch_sizes_on_device):
        torch.manual_seed(0)
        a = randn(z, m, k).view(-1, k)
        b = randn(z, n, k) if trans_b else randn(z, k, n)

        # Draw random group sizes and pad the last group so they sum to m * z.
        dist = torch.rand(z, )
        dist /= dist.sum()
        batch_sizes = (dist * m).to(torch.long)
        error = m * z - batch_sizes.sum()
        batch_sizes[-1] += error
        assert batch_sizes.sum() == (m * z)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
        expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
        self.assertTrue(allclose(out, expected_out))

        # Check gradients against the eager reference.
        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

@parameterized.parameters(False, True)
class EdgeCasesTest(parameterized.TestCase):

    def testGroupedGemm_ZeroSize(self, batch_sizes_on_device):
        torch.manual_seed(0)
        m = 16384
        k = 4096
        n = 14336
        num_experts = 8

        a = randn(num_experts, m // num_experts, k).view(-1, k)
        b = randn(num_experts, k, n)
        # Highly skewed group sizes, including an empty group, summing to m.
        batch_sizes = torch.tensor([219, 2246, 5, 8103, 1, 1117, 4693, 0]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        a.requires_grad_(True)
        b.requires_grad_(True)
        a_ref = a.detach().clone().requires_grad_(True)
        b_ref = b.detach().clone().requires_grad_(True)

        out = megablocks.gg_ops.gmm(a, b, batch_sizes)
        expected_out = gmm(a_ref, b_ref, batch_sizes)
        self.assertTrue(allclose(out, expected_out))

        out.sum().backward()
        expected_out.sum().backward()
        self.assertTrue(allclose(a.grad, a_ref.grad))
        self.assertTrue(allclose(b.grad, b_ref.grad))

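    # Illustrative note for the test below: with trans_a=True the backend
    # computes c[i] = a[start:start + size].t() @ b[start:start + size], so a
    # group of size 0 reduces over nothing and must write zeros into c[i].
    # With all-ones inputs, every element of c[i] equals that group's size.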
    def testGroupedGemm_ZeroK(self, batch_sizes_on_device):
        sz = 128
        total_tokens = 192

        a = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        b = torch.ones(total_tokens, sz).cuda().to(torch.bfloat16)
        c = torch.ones(4, sz, sz).cuda().to(torch.bfloat16)
        batch_sizes = torch.tensor([0, 128, 0, 64]).to(torch.long)
        if batch_sizes_on_device:
            batch_sizes = batch_sizes.cuda()

        megablocks.gg_backend.gmm(a, b, batch_sizes, trans_a=True, c=c)
        self.assertTrue((c[0] == 0).all())
        self.assertTrue((c[1] == 128).all())
        self.assertTrue((c[2] == 0).all())
        self.assertTrue((c[3] == 64).all())

if __name__ == '__main__':
    unittest.main()