|
import torch |
|
import megablocks |
|
|
|
|
|
def randn(bs, x, y):
    """Return a (bs, x, y) bfloat16 CUDA tensor of small uniform values.

    NOTE: despite the name, values are drawn from a *uniform* (not normal)
    distribution, centered at zero and scaled by 1 / (x * y) to keep
    matmul accumulations small in bfloat16.
    """
    # Fix precedence bug: the original `torch.rand(...) - 0.5 * 2`
    # evaluated as `rand - 1.0` (values in [-1, 0)); the intent is
    # `(rand - 0.5) * 2`, i.e. values centered at zero in [-1, 1).
    out = (torch.rand(bs, x, y) - 0.5) * 2 / (y * x)

    return out.cuda().to(torch.bfloat16)
|
|
|
|
|
def gmm(a, b, batch_sizes, trans_b=False):
    """Pure-PyTorch reference for grouped matrix multiplication.

    Rows of `a` are partitioned into consecutive groups of
    `batch_sizes[i]` rows; group i is multiplied by `b[i]` (or
    `b[i].t()` when `trans_b` is set) and the per-group products are
    concatenated along dim 0.
    """
    sizes = batch_sizes.cpu().tolist()

    # Prefix-sum the group sizes into row offsets: group i spans
    # rows [offsets[i], offsets[i + 1]).
    offsets = [0]
    for size in sizes:
        offsets.append(offsets[-1] + size)

    products = [
        a[lo:hi] @ (b[i].t() if trans_b else b[i])
        for i, (lo, hi) in enumerate(zip(offsets, offsets[1:]))
    ]
    return torch.cat(products)
|
|
|
|
|
def test_gmm():
    """End-to-end check of megablocks' grouped GEMM against the
    pure-PyTorch reference `gmm`, for both forward outputs and input
    gradients.
    """
    # Problem shape: z groups, each an (m, k) @ (k, n) matmul.
    z = 1
    m = 128
    n = 128
    k = 128
    trans_b = False
    batch_sizes_on_device = False

    # Deterministic inputs; draw order (a before b) fixes the values.
    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    if trans_b:
        b = randn(z, n, k)
    else:
        b = randn(z, k, n)
    batch_sizes = torch.tensor([m] * z)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    # Track gradients on the op's inputs and on independent clones that
    # feed the reference path.
    for tensor in (a, b):
        tensor.requires_grad_(True)
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("out", out)

    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)
    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"

    # Backprop a scalar through both paths and compare input gradients.
    out.sum().backward()
    expected_out.sum().backward()
    assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
    assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
    print("Test passed successfully!")