kernel
File size: 2,397 Bytes
9c4ca75
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

import numpy as np
import pytest
import torch

from megablocks import ops

_BINNED_SCATTER_TESTS = (
    (4, 2, 2, 1),
    (4, 2, 2, 2),
    (4, 2, 2, 4),
    (1024, 1536, 4, 1),
    (1024, 1536, 4, 2),
    (1024, 1536, 4, 4),
    (1024, 1536, 64, 1),
    (1024, 1536, 64, 2),
    (1024, 1536, 64, 4),
    (1024, 1536, 128, 1),
    (1024, 1536, 128, 2),
    (1024, 1536, 128, 4),
    (16384, 768, 4, 1),
    (16384, 768, 4, 2),
    (16384, 768, 4, 4),
    (16384, 768, 64, 1),
    (16384, 768, 64, 2),
    (16384, 768, 64, 4),
    (16384, 768, 128, 1),
    (16384, 768, 128, 2),
    (16384, 768, 128, 4),
)


def _binned_scatter_reference(
    x: torch.Tensor,
    indices: torch.Tensor,
    weights: torch.Tensor,
    bins: torch.Tensor,
    top_k: int,
    num_tokens: int,
    hidden_size: int,
    num_experts: int,
    capacity: int,
) -> torch.Tensor:
    """Compute a host-side NumPy reference for ``ops.binned_scatter``.

    ``x`` is indexed as ``x[expert, slot, :]`` (presumably the
    ``binned_gather`` output of shape ``(num_experts, capacity, hidden_size)``
    -- TODO confirm). Expert ``e`` owns the sorted slots
    ``indices[bins[e - 1]:bins[e]]`` (``bins[-1]`` taken as 0), of which only
    the first ``capacity`` are present in ``x``. Each present slot's row is
    scaled by ``weights[index]`` and summed into output row ``index // top_k``.

    Returns:
        A CPU float16 tensor of shape ``(num_tokens, hidden_size)``.
    """
    x = x.cpu().numpy()
    indices = indices.cpu().numpy()
    weights = weights.cpu().numpy()
    bins = bins.cpu().numpy()

    # Products are formed in the inputs' dtype; accumulation is in float64.
    out = np.zeros((num_tokens, hidden_size))
    start = 0
    for expert in range(num_experts):
        end = int(bins[expert])
        # x only holds `capacity` slots per expert; overflow slots are skipped.
        kept = min(capacity, end - start)
        slot_indices = indices[start:start + kept]
        scaled = weights[slot_indices][:, None] * x[expert, :kept, :]
        # np.add.at (unlike fancy-indexed +=) accumulates every occurrence of
        # a repeated output row, in order.
        np.add.at(out, slot_indices // top_k, scaled)
        start = end
    return torch.from_numpy(out).half()


@pytest.mark.gpu
@pytest.mark.parametrize(('sl', 'hs', 'ne', 'top_k'), _BINNED_SCATTER_TESTS)
def testBinnedScatter(sl: int, hs: int, ne: int, top_k: int):
    """Check ``ops.binned_scatter`` against a NumPy reference.

    Tokens are randomly routed to ``ne`` experts (``top_k`` assignments per
    token), gathered into per-expert bins with ``ops.binned_gather``, and then
    scattered back with random per-assignment weights.
    """
    # NOTE: Capacity factor == 1.
    ec = (sl * top_k) // ne

    # Create the data and indices.
    x = torch.randn((sl, hs)).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    _, indices = ops.sort(top_expert)
    bins = ops.inclusive_cumsum(ops.histogram(top_expert, ne), 0)

    # Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,)).cuda().half()

    x = ops.binned_gather(x, indices, bins, ec, top_k)

    out = ops.binned_scatter(x, indices, weights, bins, top_k)
    expected_out = _binned_scatter_reference(
        x,
        indices,
        weights,
        bins,
        top_k,
        num_tokens=sl,
        hidden_size=hs,
        num_experts=ne,
        capacity=ec,
    )

    # NOTE: We need to check approximate equality because the
    # scatter reduce uses atomics. assert_allclose raises on mismatch.
    np.testing.assert_allclose(out.cpu(), expected_out, rtol=5e-3)