File size: 264 Bytes
2595c46 |
1 2 3 4 5 6 7 8 9 10 11 12 13 |
#pragma once
#include <torch/all.h>
namespace megablocks {
// Public interface function for radix sorting with indices.
//
// Only the declaration is visible in this header; the definition is provided
// elsewhere.
//
// Parameters (roles inferred from the names -- TODO confirm against the
// implementation):
//   x        - presumably the input tensor holding the keys to sort.
//   end_bit  - presumably the upper bit position the radix sort considers;
//              verify whether it is inclusive or exclusive.
//   x_out    - presumably receives the sorted values.
//   iota_out - presumably receives the indices that accompany the sorted
//              values.
//
// Returns nothing, so results are presumably delivered through the *_out
// tensors. All tensor arguments are taken by value.
void sort(torch::Tensor x,
int end_bit,
torch::Tensor x_out,
torch::Tensor iota_out);
} // namespace megablocks