# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0

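"""Tests for the MegaBlocks replicate ops.

The forward test checks the fused ops.replicate kernel against a NumPy
reference implementation; the backward test checks the backend's
replicate_backward op against an analytic expectation.
"""
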
import numpy as np
import pytest
import torch

try:
    from megablocks._ops import ops as backend  # type: ignore
except ModuleNotFoundError as e:
    raise ModuleNotFoundError("No module named 'megablocks._ops'.") from e

from megablocks import ops


def promote_scalar(x: torch.Tensor) -> torch.Tensor:
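    """Ensure x is at least 1-dim (a 0-dim scalar becomes a 1-element tensor)."""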
    return x.view(1) if x.dim() == 0 else x


REPLICATE_TESTS = [
    (8, 1, 1),
    (8, 2, 1),
    (8, 4, 1),
    (8, 8, 1),
    (8, 2, 2),
    (8, 4, 2),
    (8, 8, 2),
    (8, 2, 4),
    (8, 4, 4),
    (8, 8, 4),
    (8, 2, 8),
    (8, 4, 8),
    (8, 8, 8),
    (16384, 2, 1),
    (16384, 4, 1),
    (16384, 8, 1),
    (16384, 16, 1),
    (16384, 32, 1),
    (16384, 64, 1),
    (16384, 128, 1),
    (16384, 2, 2),
    (16384, 4, 2),
    (16384, 8, 2),
    (16384, 16, 2),
    (16384, 32, 2),
    (16384, 64, 2),
    (16384, 128, 2),
    (16384, 2, 4),
    (16384, 4, 4),
    (16384, 8, 4),
    (16384, 16, 4),
    (16384, 32, 4),
    (16384, 64, 4),
    (16384, 128, 4),
    (16384, 2, 8),
    (16384, 4, 8),
    (16384, 8, 8),
    (16384, 16, 8),
    (16384, 32, 8),
    (16384, 64, 8),
    (16384, 128, 8),
]


@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate(tokens: int, num_centers: int, top_k: int):
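    # Assign each token to a random center, then compute per-center token
    # counts and the inclusive-cumsum bin boundaries (the end offset of each
    # center's contiguous token range).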
    tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
    bins = ops.inclusive_cumsum(tokens_per_center, 0)
    bins = promote_scalar(bins)
    center_weights = torch.randn(top_k, num_centers).cuda().half()

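    # Pure NumPy reference: for each row, repeat weight i over its center's
    # token range [start, end). E.g. with bins = [2, 5] and weights [[a, b]],
    # the output is [[a, a, b, b, b]].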
    def replicate(x: torch.Tensor, bins: torch.Tensor, num_outputs: int):
        x = x.cpu().numpy()
        bins = bins.cpu().numpy()
        out = np.zeros((x.shape[0], num_outputs))
        for batch_idx in range(x.shape[0]):
            start = 0
            for i, end in enumerate(bins):
                value = x[batch_idx, i]
                while start < end:
                    out[batch_idx, start] = value
                    start += 1
        return torch.from_numpy(out).cuda().half()

    out = ops.replicate(center_weights, bins, tokens)
    expected_out = replicate(center_weights, bins, tokens)
    assert torch.all(torch.eq(out, expected_out))


@pytest.mark.gpu
@pytest.mark.parametrize(("tokens", "num_centers", "top_k"), REPLICATE_TESTS)
def test_replicate_backward(tokens: int, num_centers: int, top_k: int):
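    # Same token-to-center routing setup as in test_replicate.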
    tokens_to_centers = torch.randint(0, num_centers, (tokens,)).cuda().int()
    tokens_per_center = ops.histogram(tokens_to_centers, num_centers)
    bins = ops.inclusive_cumsum(tokens_per_center, 0)
    bins = promote_scalar(bins)
    center_weights = torch.randn(top_k, num_centers).cuda().half()

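    # Use the replicated weights as the incoming gradient for the backward op.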
    grad = ops.replicate(center_weights, bins, tokens)

    out = torch.empty_like(center_weights)
    backend.replicate_backward(grad, bins, out)
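    # Summing the gradient over each center's bin scales every weight by the
    # number of tokens assigned to that center.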
    expected_out = center_weights * tokens_per_center.view([1, num_centers])

    # NOTE: This floating-point reduction could be a problem for training stability and accuracy.
    assert torch.allclose(out, expected_out, rtol=1e-2)