#define CUB_IGNORE_DEPRECATED_API

#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include <cstdint>
#include <string>

#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>

// Evaluate a CUDA runtime call and report failures through TORCH_CHECK,
// surfacing the CUDA error string in the exception message.
#define CUDA_CALL(code)                                      \
  do {                                                       \
    cudaError_t status = code;                               \
    std::string err = cudaGetErrorString(status);            \
    TORCH_CHECK(status == cudaSuccess, err);                 \
  } while (0)

namespace megablocks {

struct Inclusive {};
struct Exclusive {};

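// `Cumsum` selects the CUB scan variant at compile time via the Inclusive /
// Exclusive tag types above. `Run` mirrors the cub::DeviceScan calling
// convention: when `d_temp_storage` is nullptr the call only writes the
// required workspace size into `temp_storage_bytes`; a second call with a
// real allocation performs the scan.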
// Primary template: forwards to cub::DeviceScan::ExclusiveSum.
template <typename Type> struct Cumsum {

  template <
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void* d_temp_storage,
                  size_t& temp_storage_bytes,
                  InputIteratorT d_in,
                  OutputIteratorT d_out,
                  int num_items,
                  cudaStream_t stream = 0,
                  bool debug_synchronous = false) {
    // NOTE: `debug_synchronous` is unused; recent CUB releases dropped the
    // flag from DeviceScan, so it is kept here only for call-site compatibility.
    (void)debug_synchronous;
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
                                            temp_storage_bytes,
                                            d_in,
                                            d_out,
                                            num_items,
                                            stream));
  }
};

// Specialization for the Inclusive tag: forwards to cub::DeviceScan::InclusiveSum.
template <> struct Cumsum<Inclusive> {
  template <
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void* d_temp_storage,
                  size_t& temp_storage_bytes,
                  InputIteratorT d_in,
                  OutputIteratorT d_out,
                  int num_items,
                  cudaStream_t stream = 0,
                  bool debug_synchronous = false) {
    (void)debug_synchronous;  // Unused; see the note above.
    CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
                                            temp_storage_bytes,
                                            d_in,
                                            d_out,
                                            num_items,
                                            stream));
  }
};

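// Row-wise cumulative sum of a 2D tensor along dim == 1: the workspace size
// is queried once, a single int8 scratchpad is allocated with one slice per
// row, and then one DeviceScan is issued per row on the current CUDA stream.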
template <typename SumType, typename T>
void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Query the temporary storage size with a null workspace pointer.
  size_t scratchpad_bytes = 0;
  Cumsum<SumType>::Run(nullptr,
                       scratchpad_bytes,
                       x.data_ptr<T>(),
                       out.data_ptr<T>(),
                       x.size(1),
                       c10::cuda::getCurrentCUDAStream());

  // Allocate one scratchpad slice per row so the per-row scans do not share
  // temporary storage.
  auto options = torch::TensorOptions()
    .dtype(torch::kInt8)
    .device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
                                          options);

  // Launch one scan per row, offsetting the data and scratchpad pointers by
  // the row index.
  for (int i = 0; i < x.size(0); ++i) {
    void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
    Cumsum<SumType>::Run(scratchpad_ptr,
                         scratchpad_bytes,
                         x.data_ptr<T>() + x.size(1) * i,
                         out.data_ptr<T>() + x.size(1) * i,
                         x.size(1),
                         c10::cuda::getCurrentCUDAStream());
  }
}

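// Exclusive cumulative sum of each row of `x` into the preallocated tensor
// `out`. Both tensors must be 2D CUDA tensors with matching int16/int32/int64
// dtypes, and only dim == 1 is supported. Illustrative usage (tensor names
// are examples, not part of this file):
//
//   torch::Tensor x = torch::randint(0, 8, {4, 16},
//       torch::dtype(torch::kInt32).device(torch::kCUDA));
//   torch::Tensor out = torch::empty_like(x);
//   megablocks::exclusive_cumsum(x, /*dim=*/1, out);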
void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input and output tensors: 2D integer tensors on the GPU
  // with matching scalar types.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // Only cumulative sums along the last dimension are supported.
  TORCH_CHECK(dim == 1);

  // Dispatch on the scalar type using fixed-width element types.
  switch (x.scalar_type()) {
    case torch::kInt16:
      cub_cumsum<Exclusive, int16_t>(x, dim, out);
      return;
    case torch::kInt32:
      cub_cumsum<Exclusive, int32_t>(x, dim, out);
      return;
    default:
      break;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  cub_cumsum<Exclusive, int64_t>(x, dim, out);
}

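// Inclusive cumulative sum of each row of `x` into `out`; same contract as
// exclusive_cumsum above, except each output element also includes its own
// input value.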
void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input and output tensors: 2D integer tensors on the GPU
  // with matching scalar types.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
              x.scalar_type() == torch::kInt32 ||
              x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // Only cumulative sums along the last dimension are supported.
  TORCH_CHECK(dim == 1);

  switch (x.scalar_type()) {
    case torch::kInt16:
      cub_cumsum<Inclusive, int16_t>(x, dim, out);
      return;
    case torch::kInt32:
      cub_cumsum<Inclusive, int32_t>(x, dim, out);
      return;
    default:
      break;
  }
  TORCH_CHECK(x.scalar_type() == torch::kInt64);
  cub_cumsum<Inclusive, int64_t>(x, dim, out);
}

}  // namespace megablocks

#undef CUB_WRAPPED_NAMESPACE