kernel
File size: 6,362 Bytes
2595c46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include <cstdint>

#include <cub/cub.cuh>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAStream.h>
// #include <torch/extension.h>

// Evaluates the CUDA runtime expression `code` exactly once. If it does not
// return cudaSuccess, fails via TORCH_CHECK with the cudaGetErrorString() text.
// The error string is only produced on failure (TORCH_CHECK formats its message
// lazily). The local uses a macro-specific name so it cannot shadow an
// identifier referenced inside `code`.
#define CUDA_CALL(code)                                        \
  do {                                                         \
    cudaError_t megablocks_cuda_status_ = (code);              \
    TORCH_CHECK(megablocks_cuda_status_ == cudaSuccess,        \
                cudaGetErrorString(megablocks_cuda_status_));  \
  } while (0)

namespace megablocks {
namespace replicate {

// Replicates x[b, i] into out[b, bins[i - 1] : bins[i]] (with bins[-1] == 0).
//
// Expected launch layout (see ReplicateForward):
//   gridDim.x  == num_bins          blockIdx.x selects the bin
//   gridDim.y  == batch_size        blockIdx.y selects the batch row
//   gridDim.z  >= 1                 blocks sharing one bin's span
//   blockDim.x == kThreadsPerBlock  no shared memory is used
//
// x is [batch_size, num_bins], out is [batch_size, columns] and bins is
// [num_bins]. Rows are addressed with raw row-major offsets, so x and out are
// assumed contiguous. bins[i] is the exclusive end offset of bin i within an
// output row. Each thread strides through its bin by
// gridDim.z * kThreadsPerBlock, so any gridDim.z >= 1 covers the whole bin.
template <typename T, int kThreadsPerBlock>
__global__ void __launch_bounds__(kThreadsPerBlock)
  ReplicateForwardKernel(T * __restrict__ x,
                         int * __restrict__ bins,
                         T * __restrict__ out,
                         int columns) {
  const int batch_idx = blockIdx.y;
  const int num_bins = gridDim.x;
  const int bin_idx = blockIdx.x;

  // Offset to this threadblock's batch row. Use 64-bit arithmetic: for large
  // outputs batch_idx * columns can exceed INT_MAX and wrap around.
  x += static_cast<int64_t>(batch_idx) * num_bins;
  out += static_cast<int64_t>(batch_idx) * columns;

  // Load the [start, end) span of this bin and the value to replicate.
  const int start = (bin_idx > 0) ? __ldg(bins + bin_idx - 1) : 0;
  const int end = __ldg(bins + bin_idx);
  const T value = __ldg(x + bin_idx);

  // This thread's first element within the bin, and the distance between
  // consecutive elements it writes.
  int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
  const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;

  // Count the remaining span down rather than advancing bin_offset. This way
  // the loop never forms bin_offset + stride, which could overflow int for a
  // bin close to INT_MAX elements wide.
  //
  // TODO(tgale): Vectorize these stores.
  T *out_ptr = out + start + bin_offset;
  for (int remaining = end - start; bin_offset < remaining;
       remaining -= kElementsPerLoop) {
    *out_ptr = value;
    out_ptr += kElementsPerLoop;
  }
}

// Enqueues ReplicateForwardKernel on `stream` and returns the launch status
// (cudaGetLastError). No synchronization is performed, so faults raised while
// the kernel executes surface at a later synchronizing call.
//
// x:    device pointer to [batch_size, num_bins] values (row-major).
// bins: device pointer to [num_bins] per-bin end offsets into an output row.
// out:  device pointer to [batch_size, columns] output (row-major).
//
// Returns cudaSuccess without launching when there is no work (any of
// batch_size, num_bins or columns <= 0). batch_size is mapped to gridDim.y, so
// values above the CUDA grid-y limit are reported as a launch error.
template <typename T>
cudaError_t ReplicateForward(T *x,
                             int batch_size,
                             int num_bins,
                             int *bins,
                             T *out,
                             int columns,
                             cudaStream_t stream) {
  const int kThreadsPerBlock = 64;
  // Maximum gridDim.z supported by CUDA devices.
  const int64_t kMaxGridZ = 65535;

  // Nothing to do. This also avoids a zero-sized grid and a division by zero
  // below.
  if (batch_size <= 0 || num_bins <= 0 || columns <= 0) return cudaSuccess;

  // Enough blocks per bin to cover the average bin width in a single pass:
  // ceil(columns / (num_bins * kThreadsPerBlock)), computed exactly in
  // integers. The kernel strides over its bin, so clamping to the hardware
  // limit only reduces parallelism; it never skips elements.
  const int64_t threads_per_pass =
      static_cast<int64_t>(num_bins) * kThreadsPerBlock;
  int64_t group_size = (columns + threads_per_pass - 1) / threads_per_pass;
  if (group_size > kMaxGridZ) group_size = kMaxGridZ;

  dim3 block_dim(kThreadsPerBlock, 1, 1);
  dim3 grid_dim(num_bins, batch_size, static_cast<unsigned int>(group_size));
  ReplicateForwardKernel<T, kThreadsPerBlock><<<
    grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
  return cudaGetLastError();
}

// For every batch row b, writes out[b, i] = sum of grad[b, bins[i-1]:bins[i]]
// (bins[-1] taken as 0) using cub::DeviceSegmentedReduce::Sum.
//
// grad: [batch, values] float16, out: [batch, num_bins] float16,
// bins: [num_bins] int32 per-bin end offsets. Rows are addressed with raw
// row-major offsets, so all three are assumed contiguous.
// All work (memset, memcpy, reductions) is enqueued on `stream`; nothing is
// synchronized. Temporary buffers are allocated through torch::empty.
void cub_segmented_reduce(torch::Tensor grad,
                          torch::Tensor bins,
                          torch::Tensor out,
                          cudaStream_t stream) {
  const int64_t num_segments = bins.numel();

  // CUB takes separate begin/end offset arrays. Build offsets = [0, bins...]
  // so that begin = offsets[0:n] and end = offsets[1:n+1].
  torch::Tensor offsets = torch::empty(num_segments + 1, bins.options());
  int *offsets_ptr = offsets.data_ptr<int>();
  // Only the leading element needs zeroing; the copy below fills the rest.
  CUDA_CALL(cudaMemsetAsync(offsets_ptr, 0, sizeof(int), stream));
  CUDA_CALL(cudaMemcpyAsync(offsets_ptr + 1,
                            bins.data_ptr<int>(),
                            num_segments * sizeof(int),
                            cudaMemcpyDeviceToDevice,
                            stream));

  c10::Half *grad_ptr = grad.data_ptr<c10::Half>();
  c10::Half *out_ptr = out.data_ptr<c10::Half>();

  // Query the temporary buffer size (a null temp pointer makes CUB only
  // report the required byte count).
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
                                            scratchpad_bytes,
                                            grad_ptr,
                                            out_ptr,
                                            num_segments,
                                            offsets_ptr,
                                            offsets_ptr + 1,
                                            stream));

  // Allocate scratchpad.
  auto options = torch::TensorOptions()
    .dtype(torch::kInt8)
    .device(grad.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);
  int8_t *scratchpad_ptr = scratchpad.data_ptr<int8_t>();

  // Run the reduction for each batch row. Row offsets are 64-bit because
  // row * row_length can exceed INT_MAX for large tensors.
  const int64_t batch_size = grad.size(0);
  const int64_t grad_row_length = grad.size(1);
  const int64_t out_row_length = out.size(1);
  for (int64_t row = 0; row < batch_size; ++row) {
    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad_ptr,
                                              scratchpad_bytes,
                                              grad_ptr + row * grad_row_length,
                                              out_ptr + row * out_row_length,
                                              num_segments,
                                              offsets_ptr,
                                              offsets_ptr + 1,
                                              stream));
  }
}

}  // namespace replicate

// Replicates each x[b, i] across out[b, bins[i - 1] : bins[i]] (bins[-1]
// taken as 0). The kernel is enqueued on the current CUDA stream; no
// synchronization is performed.
//
// x:    [batch_size, num_bins] CUDA tensor of float16, int16 or int32.
// bins: [num_bins] CUDA int32 tensor of per-bin end offsets into a row of out.
// out:  [batch_size, columns] CUDA tensor with the same dtype as x. Elements
//       not covered by any bin are left untouched.
// All three tensors must be contiguous, because the kernel addresses rows with
// raw row-major offsets.
void replicate_forward(torch::Tensor x,
                       torch::Tensor bins,
                       torch::Tensor out) {
  // Validate the inputs.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
              x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32);
  TORCH_CHECK(x.is_contiguous());
  TORCH_CHECK(bins.is_cuda());
  TORCH_CHECK(bins.ndimension() == 1);
  TORCH_CHECK(bins.scalar_type() == torch::kInt);
  TORCH_CHECK(bins.is_contiguous());
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());
  TORCH_CHECK(out.is_contiguous());

  // Batch dimensions should match for input/output.
  TORCH_CHECK(x.size(0) == out.size(0));

  // One input for each bin (in each batch).
  TORCH_CHECK(x.size(1) == bins.size(0));

  // Exit early if there is no work to do.
  if (out.numel() == 0) return;

  switch (x.scalar_type()) {
  case torch::kFloat16:
    CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
                                          x.size(0),
                                          x.size(1),
                                          bins.data_ptr<int>(),
                                          out.data_ptr<c10::Half>(),
                                          out.size(1),
                                          c10::cuda::getCurrentCUDAStream()));
    return;
  case torch::kInt32:
    CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
                                          x.size(0),
                                          x.size(1),
                                          bins.data_ptr<int>(),
                                          out.data_ptr<int>(),
                                          out.size(1),
                                          c10::cuda::getCurrentCUDAStream()));
    return;
  case torch::kInt16:
    CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
                                          x.size(0),
                                          x.size(1),
                                          bins.data_ptr<int>(),
                                          out.data_ptr<short>(),
                                          out.size(1),
                                          c10::cuda::getCurrentCUDAStream()));
    return;
  default:
    // Unreachable: the dtype was validated above.
    TORCH_CHECK(false, "replicate_forward: unsupported dtype");
  }
}

// Writes out[b, i] = sum of grad[b, bins[i - 1] : bins[i]] (bins[-1] taken as
// 0), presumably as the backward pass of replicate_forward. All work is
// enqueued on the current CUDA stream; no synchronization is performed.
//
// grad: [batch_size, values] float16 CUDA tensor.
// bins: [num_bins] int32 CUDA tensor of per-bin end offsets into a grad row.
// out:  [batch_size, num_bins] float16 CUDA tensor.
// All three tensors must be contiguous, because rows are addressed with raw
// row-major offsets.
void replicate_backward(torch::Tensor grad,
                        torch::Tensor bins,
                        torch::Tensor out) {
  // Validate the inputs.
  TORCH_CHECK(grad.is_cuda());
  TORCH_CHECK(grad.ndimension() == 2);
  TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
  TORCH_CHECK(grad.is_contiguous());
  TORCH_CHECK(bins.is_cuda());
  TORCH_CHECK(bins.ndimension() == 1);
  TORCH_CHECK(bins.scalar_type() == torch::kInt);
  TORCH_CHECK(bins.is_contiguous());
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == torch::kFloat16);
  TORCH_CHECK(out.is_contiguous());

  // Batch dimensions should match for input/output.
  TORCH_CHECK(grad.size(0) == out.size(0));

  // One output for each bin (in each batch).
  TORCH_CHECK(out.size(1) == bins.size(0));

  // Exit early if there is no work to do (no batch rows or no bins).
  if (out.numel() == 0) return;

  replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
}

}  // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE