kernel
File size: 264 Bytes
2595c46
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
#pragma once

#include <torch/all.h>

namespace megablocks {

// Public interface function for radix sorting with indices
void sort(torch::Tensor x,
          int end_bit,
          torch::Tensor x_out,
          torch::Tensor iota_out);

} // namespace megablocks