File size: 4,141 Bytes
9c4ca75 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 |
# Copyright 2024 Databricks
# SPDX-License-Identifier: Apache-2.0
import numpy as np
import pytest
import torch
from megablocks import ops
# Parametrization for testPaddedScatter. Each tuple is (sl, hs, ne, top_k):
#   sl    - number of input rows (tokens) in x
#   hs    - width of each row
#   ne    - number of experts tokens are randomly routed to
#   top_k - expert assignments per token
# Several rows appear more than once. The test draws unseeded random data, so
# presumably each repeat is an extra random trial -- confirm before de-duplicating.
PADDED_SCATTER_TESTS = [
    # Tiny cases.
    (4, 2, 2, 2), (4, 2, 2, 1), (4, 2, 2, 1), (4, 2, 2, 1),
    (4, 2, 2, 2), (4, 2, 2, 2),
    # sl=1024, hs=1
    (1024, 1, 4, 1), (1024, 1, 4, 2), (1024, 1, 4, 4),
    (1024, 1, 4, 1), (1024, 1, 4, 2), (1024, 1, 4, 4),
    (1024, 1, 4, 1), (1024, 1, 4, 2), (1024, 1, 4, 4),
    (1024, 1, 64, 1), (1024, 1, 64, 2), (1024, 1, 64, 4),
    (1024, 1, 128, 1), (1024, 1, 128, 2), (1024, 1, 128, 4),
    # sl=1024, hs=1536
    (1024, 1536, 4, 1), (1024, 1536, 4, 2), (1024, 1536, 4, 4),
    (1024, 1536, 4, 4), (1024, 1536, 4, 4),
    (1024, 1536, 64, 1), (1024, 1536, 64, 2), (1024, 1536, 64, 4),
    (1024, 1536, 128, 1), (1024, 1536, 128, 2), (1024, 1536, 128, 4),
    (1024, 1536, 128, 1), (1024, 1536, 128, 1),
    # sl=16384, hs=768
    (16384, 768, 4, 1), (16384, 768, 4, 2), (16384, 768, 4, 4),
    (16384, 768, 64, 1), (16384, 768, 64, 2), (16384, 768, 64, 4),
    (16384, 768, 128, 1), (16384, 768, 128, 2), (16384, 768, 128, 4),
    # sl=16384, hs=1
    (16384, 1, 4, 1), (16384, 1, 4, 2), (16384, 1, 4, 4),
    (16384, 1, 64, 1), (16384, 1, 64, 2), (16384, 1, 64, 4),
    (16384, 1, 128, 1), (16384, 1, 128, 2), (16384, 1, 128, 4),
    (16384, 1, 128, 2), (16384, 1, 128, 2),
]
def _to_numpy(x: torch.Tensor) -> np.ndarray:
return x.detach().cpu().numpy()
def _padded_scatter_reference(
    x: torch.Tensor,
    indices: torch.Tensor,
    weights: torch.Tensor,
    bins: torch.Tensor,
    padded_bins: torch.Tensor,
    top_k: int,
) -> torch.Tensor:
    """NumPy reference implementation of ``ops.padded_scatter``.

    Bin ``i`` owns sorted slots ``[bins[i-1], bins[i])`` (``bins[-1]`` taken as 0).
    Its real rows sit at the start of its padded region in ``x``, which begins at
    ``padded_bins[i-1]`` (0 for the first bin). Each row is scaled by
    ``weights[indices[slot]]`` and summed into output row ``indices[slot] // top_k``.

    Args:
        x: Gathered, padded rows. Indexed as 2-D; its second dimension is the
            output width.
        indices: Sorted-slot -> original assignment index.
        weights: Per-assignment scale, indexed by original assignment index.
        bins: Inclusive cumulative count of real rows per bin.
        padded_bins: Inclusive cumulative count of padded rows per bin.
        top_k: Number of assignments per output row.

    Returns:
        A float16 CUDA tensor of shape ``(indices.shape[0] // top_k, x.shape[1])``.
    """
    x_np = _to_numpy(x)
    indices_np = _to_numpy(indices)
    weights_np = _to_numpy(weights)
    bins_np = _to_numpy(bins)
    padded_bins_np = _to_numpy(padded_bins)

    out = np.zeros((indices_np.shape[0] // top_k, x_np.shape[1]))
    start = 0
    padded_start = 0
    for end, padded_end in zip(bins_np, padded_bins_np):
        slots = indices_np[start:end]
        rows = x_np[padded_start:padded_start + (end - start)]
        # The product stays in the inputs' dtype (fp16 in this test), exactly as a
        # per-row loop would compute it. np.add.at is unbuffered, so repeated
        # destination rows (top_k > 1) accumulate in order instead of overwriting.
        np.add.at(out, slots // top_k, weights_np[slots][:, None] * rows)
        start, padded_start = end, padded_end
    return torch.from_numpy(out).cuda().half()


@pytest.mark.gpu
@pytest.mark.parametrize((
    'sl',
    'hs',
    'ne',
    'top_k',
), PADDED_SCATTER_TESTS)
def testPaddedScatter(sl: int, hs: int, ne: int, top_k: int):
    """Check ``ops.padded_scatter`` against a NumPy reference and smoke-test backward.

    Each of the ``sl`` tokens is routed to ``top_k`` random experts out of ``ne``. The
    rows are gathered into expert bins whose sizes are rounded with
    ``ops.round_up(..., 128)``. They are then scattered back with random
    per-assignment weights.
    """
    # Create the data and indices.
    x = torch.randn((sl, hs), requires_grad=True).cuda().half()

    # Randomly assign tokens to experts.
    top_expert = torch.randint(0, ne, (sl * top_k,)).cuda().int()
    bin_ids, indices = ops.sort(top_expert)
    tokens_per_expert = ops.histogram(top_expert, ne)
    padded_tokens_per_expert = ops.round_up(tokens_per_expert, 128)
    padded_bins = ops.inclusive_cumsum(padded_tokens_per_expert, 0)
    bins = ops.inclusive_cumsum(tokens_per_expert, 0)

    # Sample weights for the scatter reduce.
    weights = torch.rand((sl * top_k,), requires_grad=True).cuda().half()

    # Gather the data to prepare for backwards.
    x = ops.padded_gather(x, indices, bin_ids, bins, padded_bins, top_k)

    out = ops.padded_scatter(
        x,
        indices,
        bin_ids,
        weights,
        bins,
        padded_bins,
        top_k,
    )
    expected_out = _padded_scatter_reference(
        x,
        indices,
        weights,
        bins,
        padded_bins,
        top_k,
    )

    out.backward(torch.randn_like(out))  # sanity check backward pass

    # NOTE: We need to check approximate equality because the scatter reduce uses atomics.
    # assert_allclose raises AssertionError on mismatch, so no extra `assert` is needed.
    np.testing.assert_allclose(
        _to_numpy(out),
        _to_numpy(expected_out),
        rtol=5e-3,
    )
|