#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include "new_replicate.h"

#include <cmath>
#include <cstdint>
#include <string>

#include <cub/cub.cuh>
#include <c10/util/Half.h>
#include <c10/cuda/CUDAStream.h>

#define CUDA_CALL(code)                                \
  do {                                                 \
    cudaError_t status = code;                         \
    std::string err = cudaGetErrorString(status);      \
    TORCH_CHECK(status == cudaSuccess, err);           \
  } while (0)
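
// Added note: the do { ... } while (0) wrapper makes CUDA_CALL(...) expand to a
// single statement, so it composes safely with un-braced if/else at call sites.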

namespace megablocks {
namespace replicate {

template <typename T, int kThreadsPerBlock>
__global__ void __launch_bounds__(kThreadsPerBlock)
  ReplicateForwardKernel(T * __restrict__ x,
                         int * __restrict__ bins,
                         T * __restrict__ out,
                         int columns) {
  // Offset to this threadblock's batch.
  //
  // x is [batch_size, num_bins]
  // out is [batch_size, columns]
  // bins is [num_bins]
  int batch_idx = blockIdx.y;
  int num_bins = gridDim.x;
  x += batch_idx * num_bins;
  out += batch_idx * columns;

  // Load the start/end for this bin.
  int bin_idx = blockIdx.x;
  int start = 0;
  if (bin_idx > 0) start = __ldg(bins + bin_idx - 1);
  int end = __ldg(bins + bin_idx);

  // Load the value to replicate.
  T value = __ldg((T*)x + bin_idx);

  // Offset to this threadblock's bin and this thread's
  // offset within the bin.
  int bin_offset = blockIdx.z * kThreadsPerBlock + threadIdx.x;
  out += start + bin_offset;

  // Replicate the value to the output with a strided loop over the bin.
  //
  // TODO(tgale): Vectorize these stores.
  int num_elements = end - start;
  const int kElementsPerLoop = gridDim.z * kThreadsPerBlock;
  T *out_ptr = (T*)out;
  for (; bin_offset < num_elements; num_elements -= kElementsPerLoop) {
    *out_ptr = value;
    out_ptr += kElementsPerLoop;
  }
}
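
// Worked example (added commentary, not in the original source): with
// x = [[3, 7]], bins = [2, 5] and columns = 5, bin 0 covers output columns
// [0, 2) and bin 1 covers [2, 5), so the kernel writes out = [[3, 3, 7, 7, 7]].
// Each value x[b][i] is replicated across the half-open range
// [bins[i-1], bins[i]) of row b, with bins[-1] taken as 0.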

template <typename T>
cudaError_t ReplicateForward(T *x,
                             int batch_size,
                             int num_bins,
                             int *bins,
                             T *out,
                             int columns,
                             cudaStream_t stream) {
  const int kThreadsPerBlock = 64;
  dim3 block_dim(kThreadsPerBlock, 1, 1);
  int group_size = std::ceil((float)columns / (num_bins * kThreadsPerBlock));
  dim3 grid_dim(num_bins, batch_size, group_size);
  ReplicateForwardKernel<T, kThreadsPerBlock><<<
    grid_dim, block_dim, 0, stream>>>(x, bins, out, columns);
  return cudaGetLastError();
}
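
// Launch-configuration note (added commentary): blockIdx.x selects the bin,
// blockIdx.y selects the batch item, and gridDim.z groups of kThreadsPerBlock
// threads cooperate on a single bin. group_size is sized for the average bin
// width (columns / num_bins); bins wider than that are still covered by the
// strided loop inside the kernel.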

void cub_segmented_reduce(torch::Tensor grad,
                          torch::Tensor bins,
                          torch::Tensor out,
                          cudaStream_t stream) {
  // Prepend a zero to the bin boundaries for CUB.
  torch::Tensor offsets = torch::empty(bins.numel() + 1, bins.options());
  CUDA_CALL(cudaMemsetAsync(offsets.data_ptr<int>(),
                            0,
                            offsets.numel() * sizeof(int),
                            stream));
  CUDA_CALL(cudaMemcpyAsync(offsets.data_ptr<int>() + 1,
                            bins.data_ptr<int>(),
                            bins.numel() * sizeof(int),
                            cudaMemcpyDeviceToDevice,
                            stream));

  // Get temporary buffer size.
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceSegmentedReduce::Sum(nullptr,
                                            scratchpad_bytes,
                                            grad.data_ptr<c10::Half>(),
                                            out.data_ptr<c10::Half>(),
                                            bins.numel(),
                                            offsets.data_ptr<int>(),
                                            offsets.data_ptr<int>() + 1,
                                            stream));

  // Allocate scratchpad.
  auto options = torch::TensorOptions()
    .dtype(torch::kInt8)
    .device(grad.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes, options);

  // Run the reduction for each batch item.
  for (int i = 0; i < grad.size(0); ++i) {
    int num_bins = out.size(1);
    int num_values = grad.size(1);
    CUDA_CALL(cub::DeviceSegmentedReduce::Sum(scratchpad.data_ptr<int8_t>(),
                                              scratchpad_bytes,
                                              grad.data_ptr<c10::Half>() + i * num_values,
                                              out.data_ptr<c10::Half>() + i * num_bins,
                                              bins.numel(),
                                              offsets.data_ptr<int>(),
                                              offsets.data_ptr<int>() + 1,
                                              stream));
  }
}
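
// Added commentary: cub::DeviceSegmentedReduce::Sum follows CUB's two-phase
// protocol -- the first call with a null temp-storage pointer only writes the
// required size into scratchpad_bytes, and the second call performs the
// reduction. Segment i of a row covers the half-open range
// [offsets[i], offsets[i+1]); e.g. bins = [2, 5] gives offsets = [0, 2, 5]
// and per-row segment sums over 2 and 3 elements.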

} // namespace replicate

void replicate_forward(torch::Tensor x,
                       torch::Tensor bins,
                       torch::Tensor out) {
  // Validate the inputs.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kFloat16 ||
              x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32);
  TORCH_CHECK(bins.is_cuda());
  TORCH_CHECK(bins.ndimension() == 1);
  TORCH_CHECK(bins.scalar_type() == torch::kInt);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // Batch dimensions should match for input/output.
  TORCH_CHECK(x.size(0) == out.size(0));

  // One input for each bin (in each batch).
  TORCH_CHECK(x.size(1) == bins.size(0));

  // Exit early if there is no work to do.
  if (out.numel() == 0) return;

  switch (x.scalar_type()) {
  case torch::kFloat16:
    CUDA_CALL(replicate::ReplicateForward(x.data_ptr<c10::Half>(),
                                          x.size(0),
                                          x.size(1),
                                          bins.data_ptr<int>(),
                                          out.data_ptr<c10::Half>(),
                                          out.size(1),
                                          c10::cuda::getCurrentCUDAStream()));
    return;
  case torch::kInt32:
    CUDA_CALL(replicate::ReplicateForward(x.data_ptr<int>(),
                                          x.size(0),
                                          x.size(1),
                                          bins.data_ptr<int>(),
                                          out.data_ptr<int>(),
                                          out.size(1),
                                          c10::cuda::getCurrentCUDAStream()));
    return;
  default:
    break;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt16);
  CUDA_CALL(replicate::ReplicateForward(x.data_ptr<short>(),
                                        x.size(0),
                                        x.size(1),
                                        bins.data_ptr<int>(),
                                        out.data_ptr<short>(),
                                        out.size(1),
                                        c10::cuda::getCurrentCUDAStream()));
}
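
// Added reference sketch (not part of the original source): at the Python
// level, replicate_forward(x, bins, out) is expected to match
//
//   sizes = torch.diff(bins, prepend=torch.zeros(1, dtype=bins.dtype, device=bins.device))
//   out.copy_(torch.repeat_interleave(x, sizes, dim=1))
//
// under the assumption that `bins` holds the inclusive cumulative end of each bin.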

void replicate_backward(torch::Tensor grad,
                        torch::Tensor bins,
                        torch::Tensor out) {
  // Validate the inputs.
  TORCH_CHECK(grad.is_cuda());
  TORCH_CHECK(grad.ndimension() == 2);
  TORCH_CHECK(grad.scalar_type() == torch::kFloat16);
  TORCH_CHECK(bins.is_cuda());
  TORCH_CHECK(bins.ndimension() == 1);
  TORCH_CHECK(bins.scalar_type() == torch::kInt);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == torch::kFloat16);

  // Batch dimensions should match for input/output.
  TORCH_CHECK(grad.size(0) == out.size(0));

  // One output for each bin (in each batch).
  TORCH_CHECK(out.size(1) == bins.size(0));

  replicate::cub_segmented_reduce(grad, bins, out, c10::cuda::getCurrentCUDAStream());
}
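
// Added reference sketch (not part of the original source): the backward pass
// is a per-bin segment sum of the incoming gradient,
//
//   out[b, i] = grad[b, start_i:end_i].sum()   where start_i = (i == 0 ? 0 : bins[i-1])
//                                              and   end_i   = bins[i],
//
// which is what the CUB segmented reduction above computes row by row.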

} // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE