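# Correctness test for the megablocks grouped-GEMM kernel
# (megablocks.gg_ops.gmm): the kernel's forward output and backward
# gradients are checked against a pure-PyTorch per-group matmul reference.
# Requires a CUDA device; inputs are bfloat16 on GPU.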
import torch
import megablocks


def randn(bs, x, y):
    # Uniform values in [-1, 1], scaled by 1 / (x * y) to keep bfloat16
    # rounding error small relative to the atol used in the asserts below.
    out = (torch.rand(bs, x, y) - 0.5) * 2 / (y * x)
    return out.cuda().to(torch.bfloat16)


def gmm(a, b, batch_sizes, trans_b=False):
    # Reference implementation: loop over groups, multiplying each row-slice
    # of `a` by its group's matrix in `b`, then concatenate along dim 0.
    batch_sizes = batch_sizes.cpu().numpy()

    out = []
    start = 0
    for i, size in enumerate(batch_sizes):
        rhs = b[i, :, :].t() if trans_b else b[i, :, :]
        out.append(a[start : start + size, :] @ rhs)
        start += size
    return torch.cat(out)
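
# A minimal shape sketch of the grouped layout, assuming z groups of m rows
# each (as set up in test_gmm below):
#   a:            (z * m, k)  row-groups stacked along dim 0
#   b:            (z, k, n)   one (k, n) matrix per group,
#                 (z, n, k)   when trans_b=True
#   batch_sizes:  (z,)        rows per group, here [m] * z
#   output:       (z * m, n)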


def test_gmm():
    z = 1
    m = 128
    n = 128
    k = 128
    trans_b = False
    batch_sizes_on_device = False
    # TODO: fix to enable batch_sizes_on_device
    # batch_sizes_on_device = True

    torch.manual_seed(0)
    a = randn(z, m, k).view(-1, k)
    b = randn(z, n, k) if trans_b else randn(z, k, n)
    batch_sizes = torch.tensor([m] * z)
    if batch_sizes_on_device:
        batch_sizes = batch_sizes.cuda()

    a.requires_grad_(True)
    b.requires_grad_(True)
    # Detached clones as independent leaves, so the reference path
    # accumulates its own gradients.
    a_ref = a.detach().clone().requires_grad_(True)
    b_ref = b.detach().clone().requires_grad_(True)

    # out = ops.gmm(a, b, batch_sizes, trans_b)
    out = megablocks.gg_ops.gmm(a, b, batch_sizes, trans_b)
    print("out", out)

    expected_out = gmm(a_ref, b_ref, batch_sizes, trans_b)

    assert torch.allclose(out, expected_out, atol=1e-3), f"Expected {expected_out}, got {out}"

    # Backprop a scalar loss through both paths and compare input gradients.
    out.sum().backward()
    expected_out.sum().backward()
    assert torch.allclose(a.grad, a_ref.grad, atol=1e-3), f"Expected {a_ref.grad}, got {a.grad}"
    assert torch.allclose(b.grad, b_ref.grad, atol=1e-3), f"Expected {b_ref.grad}, got {b.grad}"
    print("Test passed successfully!")
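

# Convenience entry point so the file can be run directly; this assumes the
# test is otherwise collected by pytest via the test_ prefix.
if __name__ == "__main__":
    test_gmm()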