kernel
File size: 283 Bytes
2595c46
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
#pragma once

#include <torch/all.h>

namespace megablocks {

// Forward declarations for the public interface functions
void exclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);
void inclusive_cumsum(torch::Tensor x, int dim, torch::Tensor out);

} // namespace megablocks