kernel
File size: 2,706 Bytes
2595c46
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include <cstdint>

#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>
// #include <torch/extension.h>

// Evaluates a CUDA runtime call and raises a torch error carrying the CUDA
// error string if the call did not return cudaSuccess.
//
// TORCH_CHECK only evaluates its message arguments on the failure branch,
// so passing cudaGetErrorString(...) directly avoids constructing an error
// string (and a std::string allocation) on every successful call.
#define CUDA_CALL(code)                                             \
  do {                                                              \
    cudaError_t status = code;                                      \
    TORCH_CHECK(status == cudaSuccess, cudaGetErrorString(status)); \
  } while (0)

namespace megablocks {

// Computes a per-row histogram of `x` using CUB's DeviceHistogram.
//
// `x` is expected to be a 2-D device tensor of shape [batch, samples] whose
// elements are integers in [0, num_bins); values are binned into `num_bins`
// equal-width bins over [0, num_bins). Returns an int32 tensor of shape
// [batch, num_bins] on the same device as `x`.
//
// T is the sample element type stored in `x` (short/int/long).
template <typename T>
torch::Tensor cub_histogram(torch::Tensor x, int num_bins) {
  // Allocate the count buffer.
  auto options = torch::TensorOptions()
    .dtype(torch::kInt32)
    .device(x.device());
  torch::Tensor out = torch::empty({x.size(0), num_bins}, options);

  // Exit early if there is no work to do.
  if (out.numel() == 0) return out;

  // All CUB calls below are enqueued on the current PyTorch stream; fetch
  // it once instead of once per batch row.
  cudaStream_t stream = c10::cuda::getCurrentCUDAStream();

  // First call with a null temp-storage pointer: CUB performs no work and
  // only writes the required scratchpad size into `scratchpad_bytes`.
  size_t scratchpad_bytes = 0;
  CUDA_CALL(cub::DeviceHistogram::HistogramEven(nullptr,
						scratchpad_bytes,
						x.data_ptr<T>(),
						out.data_ptr<int>(),
						/*num_levels=*/num_bins + 1,
						/*lower_level=*/0,
						/*upper_level=*/num_bins,
						/*num_samples=*/int(x.size(1)),
						stream));

  // Allocate the scratchpad, clamped to at least one byte: a zero-element
  // tensor may hand back a null data_ptr(), and CUB interprets a null
  // temp-storage pointer as another size query, silently skipping the
  // histogram pass.
  options = torch::TensorOptions().dtype(torch::kInt8).device(x.device());
  torch::Tensor scratchpad = torch::empty(
      {int64_t(scratchpad_bytes > 0 ? scratchpad_bytes : 1)}, options);

  // Run one histogram per batch row. DeviceHistogram has no batched entry
  // point, so the rows are processed in a loop on the same stream.
  for (int i = 0; i < x.size(0); ++i) {
    CUDA_CALL(cub::DeviceHistogram::HistogramEven(scratchpad.data_ptr(),
						  scratchpad_bytes,
						  x.data_ptr<T>() + x.size(1) * i,
						  out.data_ptr<int>() + out.size(1) * i,
						  /*num_levels=*/num_bins + 1,
						  /*lower_level=*/0,
						  /*upper_level=*/num_bins,
						  /*num_samples=*/int(x.size(1)),
						  stream));
  }
  return out;
}

// Public entry point: dispatches to the templated CUB histogram over the
// supported integer sample types (int16/int32/int64).
//
// Accepts either a 1-D tensor (treated as a single batch of samples) or a
// 2-D [batch, samples] tensor; the returned int32 counts mirror the input's
// batch structure ([num_bins] for 1-D input, [batch, num_bins] for 2-D).
torch::Tensor histogram(torch::Tensor x, int num_bins) {
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 1 || x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
	      x.scalar_type() == torch::kInt32 ||
	      x.scalar_type() == torch::kInt64);

  // Promote 1-D input to a single-row batch; remembered so the output can
  // be flattened back to 1-D below.
  const bool squeeze_output = (x.ndimension() == 1);
  if (squeeze_output) x = x.view({1, x.numel()});

  torch::Tensor counts;
  switch (x.scalar_type()) {
    case torch::kInt16:
      counts = cub_histogram<short>(x, num_bins);
      break;
    case torch::kInt32:
      counts = cub_histogram<int>(x, num_bins);
      break;
    default:
      TORCH_CHECK(x.scalar_type() == torch::kInt64);
      counts = cub_histogram<long>(x, num_bins);
  }
  return squeeze_output ? counts.flatten() : counts;
}

}  // namespace megablocks

#undef CUDA_CALL
#undef CUB_WRAPPED_NAMESPACE