|
#include <cstdint> |
|
#include <c10/util/Half.h> |
|
|
|
#include <c10/cuda/CUDAStream.h> |
|
|
|
// Evaluates the CUDA runtime expression `code` exactly once and raises a
// TORCH_CHECK failure whose message is cudaGetErrorString(...) of the result
// when that result is not cudaSuccess.
//
// The argument is parenthesized so that any expression can be passed safely,
// and the error string is only looked up as a TORCH_CHECK message argument
// rather than being copied into a std::string on every (successful) call.
// The local uses a trailing-underscore name so it cannot shadow an identifier
// that appears inside `code`.
#define CUDA_CALL(code)                                        \
  do {                                                         \
    const cudaError_t cuda_call_status_ = (code);              \
    TORCH_CHECK(cuda_call_status_ == cudaSuccess,              \
                cudaGetErrorString(cuda_call_status_));        \
  } while (0)
|
|
|
namespace megablocks { |
|
namespace construct_indices { |
|
|
|
|
|
|
|
|
|
// Number of threads launched per block. Each thread walks the columns of a
// row with this stride, so a block covers a row in
// ceil(num_columns / kThreadsPerBlock) passes.
const int kThreadsPerBlock = 32;

// Fills `indices` with per-bin column indices.
//
// Launch layout:
//   blockIdx.x  - which bin to process.
//   blockIdx.y  - first row (relative to the bin) handled by this block;
//                 further rows are reached by stepping gridDim.y rows.
//   threadIdx.x - first column handled by this thread; further columns are
//                 reached by stepping kThreadsPerBlock columns.
//
// Arguments:
//   indices      - output buffer, written as rows of `num_columns` shorts.
//   num_columns  - number of entries per output row.
//   block_size   - divisor that converts the offsets in `padded_bins` into
//                  row numbers.
//   padded_bins  - padded_bins[b] is read as the end offset of bin b and
//                  padded_bins[b - 1] as its start offset (0 for b == 0).
//
// Every row belonging to bin b receives the values
//   b * num_columns + 0, b * num_columns + 1, ..., b * num_columns + (num_columns - 1)
// narrowed to short.
// NOTE(review): values that do not fit in a short are narrowed silently;
// callers presumably keep num_bins * num_columns within range - confirm.
__global__ void __launch_bounds__(kThreadsPerBlock)
ConstructIndicesKernel(short * __restrict__ indices,
                       int num_columns,
                       int block_size,
                       const int * __restrict__ padded_bins) {
  const unsigned int bin = blockIdx.x;

  // Row range [first_row, last_row) covered by this bin.
  const int bin_begin = (bin > 0) ? __ldg(padded_bins + bin - 1) : 0;
  const int bin_end = __ldg(padded_bins + bin);
  const int first_row = bin_begin / block_size;
  const int last_row = bin_end / block_size;
  const int rows_in_bin = last_row - first_row;

  // Value written to column 0 of every row in this bin.
  const unsigned int column_base = bin * num_columns;

  // Start at this block's first row and hop gridDim.y rows at a time.
  short *row_out = indices + (first_row + blockIdx.y) * num_columns;
  const unsigned int row_stride = gridDim.y * num_columns;

  // `remaining` counts down the rows left at or after this block's current
  // row; the loop stops once the bin is exhausted.
  for (int remaining = rows_in_bin - blockIdx.y; remaining > 0;
       remaining -= gridDim.y) {
    for (int col = threadIdx.x; col < num_columns; col += kThreadsPerBlock) {
      row_out[col] = col + column_base;
    }
    row_out += row_stride;
  }
}
|
|
|
// Launches ConstructIndicesKernel on `stream` to fill `indices`.
//
// Arguments:
//   indices               - device output buffer of
//                           output_block_rows * output_block_columns shorts.
//   output_block_rows     - total number of output rows across all bins.
//   output_block_columns  - number of entries per output row.
//   block_size            - divisor applied to the `padded_bins` offsets.
//   padded_bins           - device array of `num_bins` bin end offsets.
//   num_bins              - number of bins; becomes the grid's X extent.
//   stream                - stream the kernel is enqueued on.
//
// Returns cudaErrorInvalidValue for non-positive `num_bins` / `block_size` or
// negative output extents, cudaSuccess without launching when there is
// nothing to write, and otherwise the result of cudaGetLastError() after the
// launch. The launch is asynchronous; errors raised while the kernel runs are
// not reported here.
cudaError_t ConstructIndices(short * __restrict__ indices,
                             int output_block_rows,
                             int output_block_columns,
                             int block_size,
                             const int * __restrict__ padded_bins,
                             int num_bins,
                             cudaStream_t stream) {
  // num_bins is a divisor below and block_size is a divisor in the kernel.
  if (num_bins <= 0 || block_size <= 0) return cudaErrorInvalidValue;
  if (output_block_rows < 0 || output_block_columns < 0) {
    return cudaErrorInvalidValue;
  }
  // An empty output needs no launch (a zero-sized grid would be rejected).
  if (output_block_rows == 0 || output_block_columns == 0) return cudaSuccess;

  // Integer ceil(output_block_rows / num_bins). Written as (n - 1) / d + 1 so
  // it cannot overflow and stays exact for any int, unlike a float quotient.
  int grid_rows = (output_block_rows - 1) / num_bins + 1;

  // The Y extent of a grid is limited to 65535. The kernel advances by
  // gridDim.y rows per iteration, so a clamped grid still covers every row.
  const int kMaxGridDimY = 65535;
  if (grid_rows > kMaxGridDimY) grid_rows = kMaxGridDimY;

  dim3 block_dim(kThreadsPerBlock);
  dim3 grid_dim(num_bins, grid_rows);
  ConstructIndicesKernel<<<grid_dim, block_dim, 0, stream>>>(indices,
                                                             output_block_columns,
                                                             block_size,
                                                             padded_bins);
  return cudaGetLastError();
}
|
|
|
} |
|
|
|
// Validates the tensors and launches the index-construction kernel on the
// current CUDA stream.
//
// Arguments:
//   padded_bins           - 1-D CUDA int32 tensor of bin end offsets.
//   block_size            - positive divisor applied to the bin offsets.
//   output_block_rows     - number of rows in the output.
//   output_block_columns  - number of entries per output row.
//   out                   - 1-D CUDA int16 tensor with exactly
//                           output_block_rows * output_block_columns elements;
//                           written in place.
//
// Raises a TORCH_CHECK error when an argument fails validation or when the
// kernel launch reports a CUDA error. Returns without doing anything when
// `out` is empty.
void indices(torch::Tensor padded_bins,
             int block_size,
             int output_block_rows,
             int output_block_columns,
             torch::Tensor out) {
  TORCH_CHECK(padded_bins.is_cuda());
  TORCH_CHECK(padded_bins.ndimension() == 1);
  TORCH_CHECK(padded_bins.scalar_type() == torch::kInt);

  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 1);
  TORCH_CHECK(out.scalar_type() == torch::kInt16);
  // Widen before multiplying so the product cannot overflow a 32-bit int.
  TORCH_CHECK(out.numel() ==
              static_cast<int64_t>(output_block_rows) * output_block_columns);

  // Exit early if there is no work to do.
  if (out.numel() == 0) return;

  // The remaining checks only matter once a kernel is actually launched.
  TORCH_CHECK(output_block_rows > 0 && output_block_columns > 0,
              "output_block_rows and output_block_columns must be positive");
  TORCH_CHECK(block_size > 0, "block_size must be positive");
  TORCH_CHECK(padded_bins.numel() > 0, "padded_bins must not be empty");
  TORCH_CHECK(padded_bins.numel() <= static_cast<int64_t>(INT32_MAX),
              "padded_bins has too many elements");
  // The kernel indexes the raw data pointers densely.
  TORCH_CHECK(padded_bins.is_contiguous(), "padded_bins must be contiguous");
  TORCH_CHECK(out.is_contiguous(), "out must be contiguous");
  TORCH_CHECK(padded_bins.device() == out.device(),
              "padded_bins and out must be on the same device");

  CUDA_CALL(construct_indices::ConstructIndices(out.data_ptr<short>(),
                                                output_block_rows,
                                                output_block_columns,
                                                block_size,
                                                padded_bins.data_ptr<int>(),
                                                static_cast<int>(padded_bins.numel()),
                                                c10::cuda::getCurrentCUDAStream()));
}
|
|
|
} |
|
|