#include <torch/library.h>

#include "registration.h"
#include "torch_binding.h"

#include "new_cumsum.h"
#include "new_histogram.h"
#include "new_indices.h"
#include "new_replicate.h"
#include "new_sort.h"

#include "grouped_gemm/grouped_gemm.h"

// void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
torch::Tensor exclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) {
  megablocks::exclusive_cumsum(x, dim, out);
  return out;
}

// void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
torch::Tensor inclusive_cumsum_wrapper(torch::Tensor x, int64_t dim, torch::Tensor out) {
  megablocks::inclusive_cumsum(x, dim, out);
  return out;
}
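
// Worked example of the two variants along dim 0:
//   exclusive_cumsum([1, 2, 3]) writes [0, 1, 3] into out,
//   inclusive_cumsum([1, 2, 3]) writes [1, 3, 6].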

// torch::Tensor histogram(torch::Tensor x, int num_bins);
torch::Tensor histogram_wrapper(torch::Tensor x, int64_t num_bins) {
  return megablocks::histogram(x, num_bins);
}
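
// Example, assuming unit-width integer bins over [0, num_bins) (the original
// binding describes this as an "even width histogram"):
//   histogram([0, 1, 1, 3], num_bins=4) returns [1, 2, 0, 1],
// i.e. the token count per expert in MoE routing.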

// void indices(torch::Tensor padded_bins,
//   int block_size,
//   int output_block_rows,
//   int output_block_columns,
//   torch::Tensor out);
torch::Tensor indices_wrapper(torch::Tensor padded_bins,
                               int64_t block_size,
                               int64_t output_block_rows,
                               int64_t output_block_columns,
                               torch::Tensor out) {
  megablocks::indices(padded_bins, block_size, output_block_rows, output_block_columns, out);
  return out;
}

// Forward pass: replicate values from x according to bin sizes
// void replicate_forward(torch::Tensor x,
//   torch::Tensor bins,
//   torch::Tensor out);
torch::Tensor replicate_forward_wrapper(torch::Tensor x, torch::Tensor bins, torch::Tensor out) {
  megablocks::replicate_forward(x, bins, out);
  return out;
}
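
// Sketch of the forward semantics (the exact bins encoding is defined in
// new_replicate.h): each value of x is copied across its bin's span of out,
// e.g. bin sizes [2, 3] replicate x = [a, b] into out = [a, a, b, b, b].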

// Backward pass: reduce gradients back to bins using a segmented reduction
// void replicate_backward(torch::Tensor grad,
//    torch::Tensor bins,
//    torch::Tensor out);
torch::Tensor replicate_backward_wrapper(torch::Tensor grad, torch::Tensor bins, torch::Tensor out) {
  megablocks::replicate_backward(grad, bins, out);
  return out;
}
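
// The backward pass is the matching segmented reduction: with the same bin
// sizes [2, 3], grad = [g0, g1, g2, g3, g4] reduces to out = [g0+g1, g2+g3+g4].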

// Public interface for radix sorting with indices
// void sort(torch::Tensor x,
//   int end_bit,
//   torch::Tensor x_out,
//   torch::Tensor iota_out);
torch::Tensor sort_wrapper(torch::Tensor x, int64_t end_bit, torch::Tensor x_out, torch::Tensor iota_out) {
  megablocks::sort(x, end_bit, x_out, iota_out);
  return x_out;
}
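
// Example: sorting x = [3, 1, 2] with end_bit = 2 (only the low 2 bits of the
// keys are compared) yields x_out = [1, 2, 3] and iota_out = [1, 2, 0], the
// permutation that sorts x; MoE layers use this to order tokens by expert.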

// GroupedGemm operation
torch::Tensor gmm(torch::Tensor a, torch::Tensor b, torch::Tensor c, torch::Tensor batch_sizes, bool trans_a, bool trans_b) {
  grouped_gemm::GroupedGemm(a, b, c, batch_sizes, trans_a, trans_b);
  return c;
}
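
// Contract sketch, assuming the grouped_gemm convention: batch_sizes is a 1-D
// int64 tensor whose entries partition the rows of `a` into groups; group i is
// multiplied by slice b[i] and the result is written into the matching rows of
// the preallocated output `c`.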

// Reference: the original pybind11 registrations that the schemas below replace:
//
// m.def("exclusive_cumsum", &exclusive_cumsum, "batched exclusive cumsum.");
// m.def("histogram", &histogram, "even width histogram.");
// m.def("inclusive_cumsum", &inclusive_cumsum, "batched inclusive cumsum");
// m.def("indices", &indices, "indices construction for sparse matrix.");
// m.def("replicate_forward", &replicate_forward, "(fwd) replicate a vector dynamically.");
// m.def("replicate_backward", &replicate_backward, "(bwd) replicate a vector dynamically.");
// m.def("sort", &sort, "key/value sort.");

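// In the schemas below, `Tensor(a!)` is PyTorch's alias annotation: it marks
// an argument that the op mutates in place, and a matching `-> Tensor(a!)`
// return means the result aliases that argument. This is what lets callers
// pass preallocated output buffers through torch.ops.
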
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
  ops.def("exclusive_cumsum(Tensor x, int dim, Tensor(a!) out) -> Tensor(a!)");
  ops.impl("exclusive_cumsum", torch::kCUDA, &exclusive_cumsum_wrapper);

  ops.def("inclusive_cumsum(Tensor x, int dim, Tensor(a!) out) -> Tensor(a!)");
  ops.impl("inclusive_cumsum", torch::kCUDA, &inclusive_cumsum_wrapper);

  ops.def("histogram(Tensor x, int num_bins) -> Tensor");
  ops.impl("histogram", torch::kCUDA, &histogram_wrapper);

  ops.def("indices(Tensor padded_bins, int block_size, int output_block_rows, int output_block_columns, Tensor(a!) out) -> Tensor(a!)");
  ops.impl("indices", torch::kCUDA, &indices_wrapper);

  ops.def("replicate_forward(Tensor x, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
  ops.impl("replicate_forward", torch::kCUDA, &replicate_forward_wrapper);

  ops.def("replicate_backward(Tensor grad, Tensor bins, Tensor(a!) out) -> Tensor(a!)");
  ops.impl("replicate_backward", torch::kCUDA, &replicate_backward_wrapper);
  
  ops.def("sort(Tensor x, int end_bit, Tensor x_out, Tensor iota_out) -> Tensor(x_out)");
  ops.impl("sort", torch::kCUDA, &sort_wrapper);

  // Register the grouped GEMM (gmm) operation
  ops.def("gmm(Tensor(a!) a, Tensor(b!) b, Tensor(c!) c, Tensor batch_sizes, bool trans_a, bool trans_b) -> Tensor(c!)");
  ops.impl("gmm", torch::kCUDA, &gmm);
}

REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
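
// Example Python-side usage (hypothetical: `megablocks` stands in for whatever
// TORCH_EXTENSION_NAME expands to at build time):
//
//   import torch
//   bins = torch.ops.megablocks.histogram(expert_indices, 4)
//   offsets = torch.empty_like(bins)
//   torch.ops.megablocks.inclusive_cumsum(bins, 0, offsets)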