#define CUB_IGNORE_DEPRECATED_API

#undef CUB_WRAPPED_NAMESPACE
#define CUB_WRAPPED_NAMESPACE megablocks

#include <cstdint>

#include <cub/cub.cuh>
#include <c10/cuda/CUDAStream.h>
#include <torch/all.h>
// #include <torch/extension.h>

#define CUDA_CALL(code)					    \
  do {                                                      \
    cudaError_t status = code;                              \
    std::string err = cudaGetErrorString(status);           \
    TORCH_CHECK(status == cudaSuccess, err);		    \
  } while (0)

namespace megablocks {

// Tag types used to select the scan variant at compile time.
struct Inclusive {};
struct Exclusive {};

// The primary template performs an exclusive scan; the Inclusive
// specialization below performs an inclusive scan.
template <typename Type> struct Cumsum {

  template<
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void * d_temp_storage,
                  size_t & temp_storage_bytes,
                  InputIteratorT d_in,
                  OutputIteratorT d_out,
                  int num_items,
                  cudaStream_t stream = 0) {
    // NOTE: Newer CUB releases dropped the `debug_synchronous` argument
    // from cub::DeviceScan, so it is not accepted or forwarded here.
    CUDA_CALL(cub::DeviceScan::ExclusiveSum(d_temp_storage,
                                            temp_storage_bytes,
                                            d_in,
                                            d_out,
                                            num_items,
                                            stream));
  }
};

template <> struct Cumsum<Inclusive> {
  template<
    typename InputIteratorT,
    typename OutputIteratorT>
  static void Run(void * d_temp_storage,
                  size_t & temp_storage_bytes,
                  InputIteratorT d_in,
                  OutputIteratorT d_out,
                  int num_items,
                  cudaStream_t stream = 0) {
    // NOTE: As above, newer CUB no longer takes a `debug_synchronous` argument.
    CUDA_CALL(cub::DeviceScan::InclusiveSum(d_temp_storage,
                                            temp_storage_bytes,
                                            d_in,
                                            d_out,
                                            num_items,
                                            stream));
  }
};

template <typename SumType, typename T>
void cub_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Get the temporary storage size.
  //
  // NOTE: Per CUB's two-phase convention, passing a null temp-storage
  // pointer only writes the required byte count into `scratchpad_bytes`;
  // no scan is launched.
  size_t scratchpad_bytes = 0;
  Cumsum<SumType>::Run(nullptr,
                       scratchpad_bytes,
                       x.data_ptr<T>(),
                       out.data_ptr<T>(),
                       x.size(1),
                       c10::cuda::getCurrentCUDAStream());

  // Allocate the scratchpad.
  //
  // NOTE: We scale the allocation by the batch dimension so each row's
  // scan gets its own temporary buffer and the per-row scans could, in
  // principle, run in parallel.
  auto options = torch::TensorOptions()
    .dtype(torch::kInt8)
    .device(x.device());
  torch::Tensor scratchpad = torch::empty(scratchpad_bytes * x.size(0),
                                          options);

  // Run the kernel.
  //
  // NOTE: Issuing each row's scan on a different stream does not appear
  // to yield performance gains for our problem set; the overhead of
  // event/stream synchronization appears to outweigh the benefits.
  // We could write a true batched cumsum, but this would require
  // significant code duplication from CUB and we might move away
  // from this formulation anyway.
  for (int i = 0; i < x.size(0); ++i) {
    void* scratchpad_ptr = (int8_t*)scratchpad.data_ptr() + scratchpad_bytes * i;
    Cumsum<SumType>::Run(scratchpad_ptr,
			 scratchpad_bytes,
			 x.data_ptr<T>() + x.size(1) * i,
			 out.data_ptr<T>() + x.size(1) * i,
			 x.size(1),
			 c10::cuda::getCurrentCUDAStream());
  }
}

void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input matrix.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
	      x.scalar_type() == torch::kInt32 ||
	      x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // NOTE: We currently only support contraction across the contiguous
  // dimension in the matrix.
  TORCH_CHECK(dim == 1);

  switch (x.scalar_type()) {
  case torch::kInt16:
    cub_cumsum<Exclusive, int16_t>(x, dim, out);
    return;
  case torch::kInt32:
    cub_cumsum<Exclusive, int32_t>(x, dim, out);
    return;
  default:
    TORCH_CHECK(x.scalar_type() == torch::kInt64);
    cub_cumsum<Exclusive, int64_t>(x, dim, out);
  }
}

void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out) {
  // Validate the input matrix.
  TORCH_CHECK(x.is_cuda());
  TORCH_CHECK(x.ndimension() == 2);
  TORCH_CHECK(x.scalar_type() == torch::kInt16 ||
	      x.scalar_type() == torch::kInt32 ||
	      x.scalar_type() == torch::kInt64);
  TORCH_CHECK(out.is_cuda());
  TORCH_CHECK(out.ndimension() == 2);
  TORCH_CHECK(out.scalar_type() == x.scalar_type());

  // NOTE: We currently only support contraction across the contiguous
  // dimension in the matrix.
  TORCH_CHECK(dim == 1);

  switch (x.scalar_type()) {
  case torch::kInt16:
    cub_cumsum<Inclusive, int16_t>(x, dim, out);
    return;
  case torch::kInt32:
    cub_cumsum<Inclusive, int32_t>(x, dim, out);
    return;
  default:
    TORCH_CHECK(x.scalar_type() == torch::kInt64);
    cub_cumsum<Inclusive, int64_t>(x, dim, out);
  }
}
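
// Illustrative usage sketch: callers are expected to pass a contiguous 2D
// integer CUDA tensor and a preallocated `out` of the same shape and dtype.
// The shapes and values below are made up for illustration only.
//
//   auto opts = torch::dtype(torch::kInt32).device(torch::kCUDA);
//   torch::Tensor x = torch::randint(/*low=*/0, /*high=*/8, {4, 1024}, opts);
//   torch::Tensor out = torch::empty_like(x);
//   exclusive_cumsum(x, /*dim=*/1, out);  // out[i][0] == 0 for every row i
//   inclusive_cumsum(x, /*dim=*/1, out);  // out[i][last] == sum of row i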

} // namespace megablocks

#undef CUB_WRAPPED_NAMESPACE
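
// Python bindings are not defined in this translation unit. As a hypothetical
// sketch only (assuming a pybind11 extension built against the
// torch/extension.h header that is commented out above), the entry points
// could be exposed as:
//
//   #include <torch/extension.h>
//
//   PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
//     m.def("exclusive_cumsum", &megablocks::exclusive_cumsum,
//           "Row-wise exclusive cumulative sum over a 2D integer CUDA tensor.");
//     m.def("inclusive_cumsum", &megablocks::inclusive_cumsum,
//           "Row-wise inclusive cumulative sum over a 2D integer CUDA tensor.");
//   }
//
// The module layout and exposed names above are assumptions, not the
// project's actual binding code.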