Commit
·
db03e28
0
Parent(s):
Convert causal-conv1d to a Hub kernel
Browse files- README.md +10 -0
- build.toml +49 -0
- causal-conv1d/causal_conv1d.cpp +486 -0
- causal-conv1d/causal_conv1d.h +81 -0
- causal-conv1d/causal_conv1d_bwd.cu +627 -0
- causal-conv1d/causal_conv1d_common.h +98 -0
- causal-conv1d/causal_conv1d_fwd.cu +399 -0
- causal-conv1d/causal_conv1d_update.cu +137 -0
- causal-conv1d/static_switch.h +25 -0
- flake.lock +168 -0
- flake.nix +18 -0
- tests/test_causal_conv1d.py +353 -0
- torch-ext/causal_conv1d/__init__.py +4 -0
- torch-ext/causal_conv1d/causal_conv1d_interface.py +242 -0
- torch-ext/causal_conv1d/causal_conv1d_varlen.py +86 -0
- torch-ext/causal_conv1d/cpp_functions.py +96 -0
- torch-ext/pytorch_shim.h +105 -0
- torch-ext/torch_binding.cpp +32 -0
- torch-ext/torch_binding.h +39 -0
README.md
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: bsd-3-clause
|
| 3 |
+
tags:
|
| 4 |
+
- kernel
|
| 5 |
+
---
|
| 6 |
+
|
| 7 |
+
## causal-conv1d
|
| 8 |
+
|
| 9 |
+
Causal depthwise conv1d kernel by Tri Dao. Source: https://github.com/Dao-AILab/causal-conv1d/
|
| 10 |
+
|
build.toml
ADDED
|
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[general]
|
| 2 |
+
name = "causal_conv1d"
|
| 3 |
+
universal = false
|
| 4 |
+
|
| 5 |
+
[torch]
|
| 6 |
+
src = [
|
| 7 |
+
"torch-ext/pytorch_shim.h",
|
| 8 |
+
"torch-ext/torch_binding.cpp",
|
| 9 |
+
"torch-ext/torch_binding.h"
|
| 10 |
+
]
|
| 11 |
+
|
| 12 |
+
[kernel.causal_conv1d]
|
| 13 |
+
backend = "cuda"
|
| 14 |
+
src = [
|
| 15 |
+
"causal-conv1d/causal_conv1d_bwd.cu",
|
| 16 |
+
"causal-conv1d/causal_conv1d_common.h",
|
| 17 |
+
"causal-conv1d/causal_conv1d.cpp",
|
| 18 |
+
"causal-conv1d/causal_conv1d_fwd.cu",
|
| 19 |
+
"causal-conv1d/causal_conv1d.h",
|
| 20 |
+
"causal-conv1d/causal_conv1d_update.cu",
|
| 21 |
+
"causal-conv1d/static_switch.h",
|
| 22 |
+
]
|
| 23 |
+
include = [ "causal-conv1d" ]
|
| 24 |
+
depends = [ "torch" ]
|
| 25 |
+
|
| 26 |
+
[kernel.causal_conv1d_rocm]
|
| 27 |
+
backend = "rocm"
|
| 28 |
+
rocm-archs = [
|
| 29 |
+
"gfx906",
|
| 30 |
+
"gfx908",
|
| 31 |
+
"gfx90a",
|
| 32 |
+
"gfx940",
|
| 33 |
+
"gfx941",
|
| 34 |
+
"gfx942",
|
| 35 |
+
"gfx1030",
|
| 36 |
+
"gfx1100",
|
| 37 |
+
"gfx1101",
|
| 38 |
+
]
|
| 39 |
+
src = [
|
| 40 |
+
"causal-conv1d/causal_conv1d_bwd.cu",
|
| 41 |
+
"causal-conv1d/causal_conv1d_common.h",
|
| 42 |
+
"causal-conv1d/causal_conv1d.cpp",
|
| 43 |
+
"causal-conv1d/causal_conv1d_fwd.cu",
|
| 44 |
+
"causal-conv1d/causal_conv1d.h",
|
| 45 |
+
"causal-conv1d/causal_conv1d_update.cu",
|
| 46 |
+
"causal-conv1d/static_switch.h",
|
| 47 |
+
]
|
| 48 |
+
include = [ "causal-conv1d" ]
|
| 49 |
+
depends = [ "torch" ]
|
causal-conv1d/causal_conv1d.cpp
ADDED
|
@@ -0,0 +1,486 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2024, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <torch/all.h>
|
| 6 |
+
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
|
| 7 |
+
#include <c10/core/DeviceGuard.h>
|
| 8 |
+
#else
|
| 9 |
+
#include <c10/cuda/CUDAGuard.h>
|
| 10 |
+
#endif
|
| 11 |
+
|
| 12 |
+
#include <c10/cuda/CUDAStream.h>
|
| 13 |
+
#include <vector>
|
| 14 |
+
|
| 15 |
+
#include "causal_conv1d.h"
|
| 16 |
+
|
| 17 |
+
#define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
|
| 18 |
+
|
| 19 |
+
#define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
|
| 20 |
+
if (ITYPE == at::ScalarType::Half) { \
|
| 21 |
+
using input_t = at::Half; \
|
| 22 |
+
__VA_ARGS__(); \
|
| 23 |
+
} else if (ITYPE == at::ScalarType::BFloat16) { \
|
| 24 |
+
using input_t = at::BFloat16; \
|
| 25 |
+
__VA_ARGS__(); \
|
| 26 |
+
} else if (ITYPE == at::ScalarType::Float) { \
|
| 27 |
+
using input_t = float; \
|
| 28 |
+
__VA_ARGS__(); \
|
| 29 |
+
} else { \
|
| 30 |
+
AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
|
| 31 |
+
}
|
| 32 |
+
|
| 33 |
+
#define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
|
| 34 |
+
if (WTYPE == at::ScalarType::Half) { \
|
| 35 |
+
using weight_t = at::Half; \
|
| 36 |
+
__VA_ARGS__(); \
|
| 37 |
+
} else if (WTYPE == at::ScalarType::BFloat16) { \
|
| 38 |
+
using weight_t = at::BFloat16; \
|
| 39 |
+
__VA_ARGS__(); \
|
| 40 |
+
} else if (WTYPE == at::ScalarType::Float) { \
|
| 41 |
+
using weight_t = float; \
|
| 42 |
+
__VA_ARGS__(); \
|
| 43 |
+
} else { \
|
| 44 |
+
AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
|
| 45 |
+
}
|
| 46 |
+
|
| 47 |
+
template<typename input_t, typename weight_t>
|
| 48 |
+
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 49 |
+
template <typename input_t, typename weight_t>
|
| 50 |
+
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 51 |
+
|
| 52 |
+
template<typename input_t, typename weight_t>
|
| 53 |
+
void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 54 |
+
template<typename input_t, typename weight_t>
|
| 55 |
+
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 56 |
+
|
| 57 |
+
template<typename input_t, typename weight_t>
|
| 58 |
+
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 59 |
+
|
| 60 |
+
void set_conv_params_fwd(ConvParamsBase ¶ms,
|
| 61 |
+
// sizes
|
| 62 |
+
const size_t batch,
|
| 63 |
+
const size_t dim,
|
| 64 |
+
const size_t seqlen,
|
| 65 |
+
const size_t width,
|
| 66 |
+
// device pointers
|
| 67 |
+
const at::Tensor x,
|
| 68 |
+
const at::Tensor weight,
|
| 69 |
+
const at::Tensor out,
|
| 70 |
+
void* bias_ptr,
|
| 71 |
+
bool silu_activation) {
|
| 72 |
+
|
| 73 |
+
// Reset the parameters
|
| 74 |
+
memset(¶ms, 0, sizeof(params));
|
| 75 |
+
|
| 76 |
+
params.batch = batch;
|
| 77 |
+
params.dim = dim;
|
| 78 |
+
params.seqlen = seqlen;
|
| 79 |
+
params.width = width;
|
| 80 |
+
|
| 81 |
+
params.silu_activation = silu_activation;
|
| 82 |
+
|
| 83 |
+
// Set the pointers and strides.
|
| 84 |
+
params.x_ptr = x.data_ptr();
|
| 85 |
+
params.weight_ptr = weight.data_ptr();
|
| 86 |
+
params.bias_ptr = bias_ptr;
|
| 87 |
+
params.out_ptr = out.data_ptr();
|
| 88 |
+
// All stride are in elements, not bytes.
|
| 89 |
+
params.x_batch_stride = x.stride(0);
|
| 90 |
+
params.x_c_stride = x.stride(1);
|
| 91 |
+
params.x_l_stride = x.stride(-1);
|
| 92 |
+
params.weight_c_stride = weight.stride(0);
|
| 93 |
+
params.weight_width_stride = weight.stride(1);
|
| 94 |
+
params.out_batch_stride = out.stride(0);
|
| 95 |
+
params.out_c_stride = out.stride(1);
|
| 96 |
+
params.out_l_stride = out.stride(-1);
|
| 97 |
+
}
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
void set_conv_params_bwd(ConvParamsBwd ¶ms,
|
| 101 |
+
// sizes
|
| 102 |
+
const size_t batch,
|
| 103 |
+
const size_t dim,
|
| 104 |
+
const size_t seqlen,
|
| 105 |
+
const size_t width,
|
| 106 |
+
// device pointers
|
| 107 |
+
const at::Tensor x,
|
| 108 |
+
const at::Tensor weight,
|
| 109 |
+
void* bias_ptr,
|
| 110 |
+
const at::Tensor dout,
|
| 111 |
+
const at::Tensor dx,
|
| 112 |
+
const at::Tensor dweight,
|
| 113 |
+
void* dbias_ptr,
|
| 114 |
+
bool silu_activation) {
|
| 115 |
+
// Pass in "dout" instead of "out", we're not gonna use "out" at all.
|
| 116 |
+
set_conv_params_fwd(params, batch, dim, seqlen, width,
|
| 117 |
+
x, weight, dout, bias_ptr, silu_activation);
|
| 118 |
+
|
| 119 |
+
// Set the pointers and strides.
|
| 120 |
+
params.dout_ptr = dout.data_ptr();
|
| 121 |
+
params.dx_ptr = dx.data_ptr();
|
| 122 |
+
params.dweight_ptr = dweight.data_ptr();
|
| 123 |
+
params.dbias_ptr = dbias_ptr;
|
| 124 |
+
// All stride are in elements, not bytes.
|
| 125 |
+
params.dout_batch_stride = dout.stride(0);
|
| 126 |
+
params.dout_c_stride = dout.stride(1);
|
| 127 |
+
params.dout_l_stride = dout.stride(2);
|
| 128 |
+
params.dweight_c_stride = dweight.stride(0);
|
| 129 |
+
params.dweight_width_stride = dweight.stride(1);
|
| 130 |
+
params.dx_batch_stride = dx.stride(0);
|
| 131 |
+
params.dx_c_stride = dx.stride(1);
|
| 132 |
+
params.dx_l_stride = dx.stride(2);
|
| 133 |
+
}
|
| 134 |
+
|
| 135 |
+
void
|
| 136 |
+
causal_conv1d_fwd(const at::Tensor &x,
|
| 137 |
+
const at::Tensor &weight,
|
| 138 |
+
const c10::optional<at::Tensor> &bias_,
|
| 139 |
+
const c10::optional<at::Tensor> &seq_idx_,
|
| 140 |
+
const c10::optional<at::Tensor> &initial_states_,
|
| 141 |
+
at::Tensor &out,
|
| 142 |
+
c10::optional<at::Tensor> &final_states_out_,
|
| 143 |
+
bool silu_activation) {
|
| 144 |
+
auto input_type = x.scalar_type();
|
| 145 |
+
auto weight_type = weight.scalar_type();
|
| 146 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 147 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
| 148 |
+
|
| 149 |
+
TORCH_CHECK(x.is_cuda());
|
| 150 |
+
TORCH_CHECK(weight.is_cuda());
|
| 151 |
+
|
| 152 |
+
const auto sizes = x.sizes();
|
| 153 |
+
const int batch_size = sizes[0];
|
| 154 |
+
const int dim = sizes[1];
|
| 155 |
+
const int seqlen = sizes[2];
|
| 156 |
+
const int width = weight.size(-1);
|
| 157 |
+
|
| 158 |
+
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
| 159 |
+
CHECK_SHAPE(weight, dim, width);
|
| 160 |
+
|
| 161 |
+
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
| 162 |
+
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
| 163 |
+
|
| 164 |
+
if (is_channel_last) {
|
| 165 |
+
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
| 166 |
+
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
| 167 |
+
}
|
| 168 |
+
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
| 169 |
+
|
| 170 |
+
if (bias_.has_value()) {
|
| 171 |
+
auto bias = bias_.value();
|
| 172 |
+
TORCH_CHECK(bias.scalar_type() == weight_type);
|
| 173 |
+
TORCH_CHECK(bias.is_cuda());
|
| 174 |
+
TORCH_CHECK(bias.stride(-1) == 1);
|
| 175 |
+
CHECK_SHAPE(bias, dim);
|
| 176 |
+
}
|
| 177 |
+
|
| 178 |
+
if (seq_idx_.has_value()) {
|
| 179 |
+
TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
|
| 180 |
+
auto seq_idx = seq_idx_.value();
|
| 181 |
+
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
| 182 |
+
TORCH_CHECK(seq_idx.is_cuda());
|
| 183 |
+
TORCH_CHECK(seq_idx.is_contiguous());
|
| 184 |
+
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
| 185 |
+
}
|
| 186 |
+
|
| 187 |
+
ConvParamsBase params;
|
| 188 |
+
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
| 189 |
+
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
| 190 |
+
silu_activation);
|
| 191 |
+
|
| 192 |
+
if (seq_idx_.has_value()) {
|
| 193 |
+
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
| 194 |
+
} else {
|
| 195 |
+
params.seq_idx_ptr = nullptr;
|
| 196 |
+
}
|
| 197 |
+
|
| 198 |
+
if (initial_states_.has_value()) {
|
| 199 |
+
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
| 200 |
+
auto initial_states = initial_states_.value();
|
| 201 |
+
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
| 202 |
+
TORCH_CHECK(initial_states.is_cuda());
|
| 203 |
+
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
| 204 |
+
TORCH_CHECK(initial_states.stride(1) == 1);
|
| 205 |
+
params.initial_states_ptr = initial_states.data_ptr();
|
| 206 |
+
params.initial_states_batch_stride = initial_states.stride(0);
|
| 207 |
+
params.initial_states_c_stride = initial_states.stride(1);
|
| 208 |
+
params.initial_states_l_stride = initial_states.stride(2);
|
| 209 |
+
} else {
|
| 210 |
+
params.initial_states_ptr = nullptr;
|
| 211 |
+
}
|
| 212 |
+
|
| 213 |
+
if (final_states_out_.has_value()) {
|
| 214 |
+
TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
|
| 215 |
+
auto final_states = final_states_out_.value();
|
| 216 |
+
TORCH_CHECK(final_states.scalar_type() == input_type);
|
| 217 |
+
TORCH_CHECK(final_states.is_cuda());
|
| 218 |
+
CHECK_SHAPE(final_states, batch_size, dim, width - 1);
|
| 219 |
+
TORCH_CHECK(final_states.stride(1) == 1);
|
| 220 |
+
params.final_states_ptr = final_states.data_ptr();
|
| 221 |
+
params.final_states_batch_stride = final_states.stride(0);
|
| 222 |
+
params.final_states_c_stride = final_states.stride(1);
|
| 223 |
+
params.final_states_l_stride = final_states.stride(2);
|
| 224 |
+
} else {
|
| 225 |
+
params.final_states_ptr = nullptr;
|
| 226 |
+
}
|
| 227 |
+
|
| 228 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 229 |
+
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
|
| 230 |
+
c10::DeviceGuard device_guard(x.device());
|
| 231 |
+
#else
|
| 232 |
+
at::cuda::CUDAGuard device_guard{x.device()};
|
| 233 |
+
#endif
|
| 234 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 235 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
|
| 236 |
+
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
|
| 237 |
+
if (!is_channel_last) {
|
| 238 |
+
causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
|
| 239 |
+
} else {
|
| 240 |
+
causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
|
| 241 |
+
}
|
| 242 |
+
});
|
| 243 |
+
});
|
| 244 |
+
}
|
| 245 |
+
|
| 246 |
+
void
|
| 247 |
+
causal_conv1d_bwd(const at::Tensor &x,
|
| 248 |
+
const at::Tensor &weight,
|
| 249 |
+
const c10::optional<at::Tensor> &bias_,
|
| 250 |
+
at::Tensor &dout,
|
| 251 |
+
const c10::optional<at::Tensor> &seq_idx_,
|
| 252 |
+
const c10::optional<at::Tensor> &initial_states_,
|
| 253 |
+
const c10::optional<at::Tensor> &dfinal_states_,
|
| 254 |
+
at::Tensor &dx,
|
| 255 |
+
at::Tensor &dweight,
|
| 256 |
+
c10::optional<at::Tensor> &dbias_,
|
| 257 |
+
c10::optional<at::Tensor> &dinitial_states_,
|
| 258 |
+
bool silu_activation) {
|
| 259 |
+
auto input_type = x.scalar_type();
|
| 260 |
+
auto weight_type = weight.scalar_type();
|
| 261 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 262 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
| 263 |
+
|
| 264 |
+
TORCH_CHECK(x.is_cuda());
|
| 265 |
+
TORCH_CHECK(weight.is_cuda());
|
| 266 |
+
TORCH_CHECK(dout.is_cuda());
|
| 267 |
+
TORCH_CHECK(bias_.has_value() == dbias_.has_value());
|
| 268 |
+
|
| 269 |
+
const auto sizes = x.sizes();
|
| 270 |
+
const int batch_size = sizes[0];
|
| 271 |
+
const int dim = sizes[1];
|
| 272 |
+
const int seqlen = sizes[2];
|
| 273 |
+
const int width = weight.size(-1);
|
| 274 |
+
|
| 275 |
+
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
| 276 |
+
|
| 277 |
+
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
| 278 |
+
CHECK_SHAPE(weight, dim, width);
|
| 279 |
+
CHECK_SHAPE(dout, batch_size, dim, seqlen);
|
| 280 |
+
|
| 281 |
+
TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
|
| 282 |
+
const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
|
| 283 |
+
if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
|
| 284 |
+
if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
|
| 285 |
+
|
| 286 |
+
if (is_channel_last) {
|
| 287 |
+
TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
|
| 288 |
+
TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
|
| 289 |
+
TORCH_CHECK(dout.stride(2) % 8 == 0 and dout.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.stride(0) and dout.stride(2)) to be multiples of 8");
|
| 290 |
+
}
|
| 291 |
+
|
| 292 |
+
if (bias_.has_value()) {
|
| 293 |
+
auto bias = bias_.value();
|
| 294 |
+
TORCH_CHECK(bias.scalar_type() == weight_type);
|
| 295 |
+
TORCH_CHECK(bias.is_cuda());
|
| 296 |
+
TORCH_CHECK(bias.stride(-1) == 1);
|
| 297 |
+
CHECK_SHAPE(bias, dim);
|
| 298 |
+
}
|
| 299 |
+
|
| 300 |
+
if (seq_idx_.has_value()) {
|
| 301 |
+
TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout");
|
| 302 |
+
auto seq_idx = seq_idx_.value();
|
| 303 |
+
TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
|
| 304 |
+
TORCH_CHECK(seq_idx.is_cuda());
|
| 305 |
+
TORCH_CHECK(seq_idx.is_contiguous());
|
| 306 |
+
CHECK_SHAPE(seq_idx, batch_size, seqlen);
|
| 307 |
+
}
|
| 308 |
+
|
| 309 |
+
TORCH_CHECK(dx.scalar_type() == input_type);
|
| 310 |
+
TORCH_CHECK(dx.is_cuda());
|
| 311 |
+
CHECK_SHAPE(dx, batch_size, dim, seqlen);
|
| 312 |
+
if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
|
| 313 |
+
if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
|
| 314 |
+
|
| 315 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 316 |
+
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
|
| 317 |
+
c10::Device device = x.device();
|
| 318 |
+
c10::DeviceGuard device_guard(device);
|
| 319 |
+
#else
|
| 320 |
+
at::cuda::CUDAGuard device_guard{x.device()};
|
| 321 |
+
#endif
|
| 322 |
+
ConvParamsBwd params;
|
| 323 |
+
set_conv_params_bwd(params, batch_size, dim, seqlen, width,
|
| 324 |
+
x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
| 325 |
+
dout, dx, dweight, bias_.has_value() ? dbias_.value().data_ptr() : nullptr,
|
| 326 |
+
silu_activation);
|
| 327 |
+
|
| 328 |
+
if (seq_idx_.has_value()) {
|
| 329 |
+
params.seq_idx_ptr = seq_idx_.value().data_ptr();
|
| 330 |
+
} else {
|
| 331 |
+
params.seq_idx_ptr = nullptr;
|
| 332 |
+
}
|
| 333 |
+
|
| 334 |
+
if (initial_states_.has_value()) {
|
| 335 |
+
TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
|
| 336 |
+
auto initial_states = initial_states_.value();
|
| 337 |
+
TORCH_CHECK(initial_states.scalar_type() == input_type);
|
| 338 |
+
TORCH_CHECK(initial_states.is_cuda());
|
| 339 |
+
CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
|
| 340 |
+
TORCH_CHECK(initial_states.stride(1) == 1);
|
| 341 |
+
params.initial_states_ptr = initial_states.data_ptr();
|
| 342 |
+
params.initial_states_batch_stride = initial_states.stride(0);
|
| 343 |
+
params.initial_states_c_stride = initial_states.stride(1);
|
| 344 |
+
params.initial_states_l_stride = initial_states.stride(2);
|
| 345 |
+
} else {
|
| 346 |
+
params.initial_states_ptr = nullptr;
|
| 347 |
+
}
|
| 348 |
+
|
| 349 |
+
if (dfinal_states_.has_value()) {
|
| 350 |
+
TORCH_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout");
|
| 351 |
+
auto dfinal_states = dfinal_states_.value();
|
| 352 |
+
TORCH_CHECK(dfinal_states.scalar_type() == input_type);
|
| 353 |
+
TORCH_CHECK(dfinal_states.is_cuda());
|
| 354 |
+
CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1);
|
| 355 |
+
params.dfinal_states_ptr = dfinal_states.data_ptr();
|
| 356 |
+
params.dfinal_states_batch_stride = dfinal_states.stride(0);
|
| 357 |
+
params.dfinal_states_c_stride = dfinal_states.stride(1);
|
| 358 |
+
params.dfinal_states_l_stride = dfinal_states.stride(2);
|
| 359 |
+
} else {
|
| 360 |
+
params.dfinal_states_ptr = nullptr;
|
| 361 |
+
}
|
| 362 |
+
|
| 363 |
+
if (dinitial_states_.has_value()) {
|
| 364 |
+
at::Tensor dinitial_states = dinitial_states_.value();
|
| 365 |
+
TORCH_CHECK(dinitial_states.stride(1) == 1);
|
| 366 |
+
params.dinitial_states_ptr = dinitial_states.data_ptr();
|
| 367 |
+
params.dinitial_states_batch_stride = dinitial_states.stride(0);
|
| 368 |
+
params.dinitial_states_c_stride = dinitial_states.stride(1);
|
| 369 |
+
params.dinitial_states_l_stride = dinitial_states.stride(2);
|
| 370 |
+
} else {
|
| 371 |
+
params.dinitial_states_ptr = nullptr;
|
| 372 |
+
}
|
| 373 |
+
|
| 374 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 375 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
|
| 376 |
+
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
|
| 377 |
+
if (!is_channel_last) {
|
| 378 |
+
causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
|
| 379 |
+
} else {
|
| 380 |
+
causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
|
| 381 |
+
}
|
| 382 |
+
});
|
| 383 |
+
});
|
| 384 |
+
}
|
| 385 |
+
|
| 386 |
+
void
|
| 387 |
+
causal_conv1d_update(const at::Tensor &x,
|
| 388 |
+
const at::Tensor &conv_state,
|
| 389 |
+
const at::Tensor &weight,
|
| 390 |
+
const c10::optional<at::Tensor> &bias_,
|
| 391 |
+
at::Tensor &out,
|
| 392 |
+
bool silu_activation,
|
| 393 |
+
const c10::optional<at::Tensor> &cache_seqlens_,
|
| 394 |
+
const c10::optional<at::Tensor> &conv_state_indices_
|
| 395 |
+
) {
|
| 396 |
+
auto input_type = x.scalar_type();
|
| 397 |
+
auto weight_type = weight.scalar_type();
|
| 398 |
+
TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
|
| 399 |
+
TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
|
| 400 |
+
TORCH_CHECK(conv_state.scalar_type() == input_type);
|
| 401 |
+
|
| 402 |
+
TORCH_CHECK(x.is_cuda());
|
| 403 |
+
TORCH_CHECK(conv_state.is_cuda());
|
| 404 |
+
TORCH_CHECK(weight.is_cuda());
|
| 405 |
+
|
| 406 |
+
const auto sizes = x.sizes();
|
| 407 |
+
const int batch_size = sizes[0];
|
| 408 |
+
const int dim = sizes[1];
|
| 409 |
+
const int seqlen = sizes[2];
|
| 410 |
+
const int width = weight.size(-1);
|
| 411 |
+
const int conv_state_len = conv_state.size(2);
|
| 412 |
+
TORCH_CHECK(conv_state_len >= width - 1);
|
| 413 |
+
|
| 414 |
+
CHECK_SHAPE(x, batch_size, dim, seqlen);
|
| 415 |
+
CHECK_SHAPE(weight, dim, width);
|
| 416 |
+
|
| 417 |
+
TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
|
| 418 |
+
|
| 419 |
+
if (bias_.has_value()) {
|
| 420 |
+
auto bias = bias_.value();
|
| 421 |
+
TORCH_CHECK(bias.scalar_type() == weight_type);
|
| 422 |
+
TORCH_CHECK(bias.is_cuda());
|
| 423 |
+
TORCH_CHECK(bias.stride(-1) == 1);
|
| 424 |
+
CHECK_SHAPE(bias, dim);
|
| 425 |
+
}
|
| 426 |
+
|
| 427 |
+
ConvParamsBase params;
|
| 428 |
+
set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
|
| 429 |
+
bias_.has_value() ? bias_.value().data_ptr() : nullptr,
|
| 430 |
+
silu_activation);
|
| 431 |
+
params.conv_state_ptr = conv_state.data_ptr();
|
| 432 |
+
params.conv_state_len = conv_state_len;
|
| 433 |
+
// All stride are in elements, not bytes.
|
| 434 |
+
params.conv_state_batch_stride = conv_state.stride(0);
|
| 435 |
+
params.conv_state_c_stride = conv_state.stride(1);
|
| 436 |
+
params.conv_state_l_stride = conv_state.stride(2);
|
| 437 |
+
|
| 438 |
+
if (conv_state_indices_.has_value()) {
|
| 439 |
+
auto conv_state_indices = conv_state_indices_.value();
|
| 440 |
+
TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
|
| 441 |
+
TORCH_CHECK(conv_state_indices.is_cuda());
|
| 442 |
+
TORCH_CHECK(conv_state_indices.stride(0) == 1)
|
| 443 |
+
CHECK_SHAPE(conv_state_indices, batch_size);
|
| 444 |
+
|
| 445 |
+
int conv_state_entries = conv_state.size(0);
|
| 446 |
+
CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
|
| 447 |
+
|
| 448 |
+
params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
|
| 449 |
+
} else {
|
| 450 |
+
CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
|
| 451 |
+
params.conv_state_indices_ptr = nullptr;
|
| 452 |
+
}
|
| 453 |
+
|
| 454 |
+
if (cache_seqlens_.has_value()) {
|
| 455 |
+
auto cache_seqlens = cache_seqlens_.value();
|
| 456 |
+
TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
|
| 457 |
+
TORCH_CHECK(cache_seqlens.is_cuda());
|
| 458 |
+
TORCH_CHECK(cache_seqlens.stride(-1) == 1);
|
| 459 |
+
CHECK_SHAPE(cache_seqlens, batch_size);
|
| 460 |
+
params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
|
| 461 |
+
} else {
|
| 462 |
+
params.cache_seqlens = nullptr;
|
| 463 |
+
}
|
| 464 |
+
|
| 465 |
+
// Otherwise the kernel will be launched from cuda:0 device
|
| 466 |
+
#if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
|
| 467 |
+
c10::Device device = x.device();
|
| 468 |
+
c10::DeviceGuard device_guard(device);
|
| 469 |
+
#else
|
| 470 |
+
at::cuda::CUDAGuard device_guard{x.device()};
|
| 471 |
+
#endif
|
| 472 |
+
auto stream = at::cuda::getCurrentCUDAStream().stream();
|
| 473 |
+
DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
|
| 474 |
+
DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
|
| 475 |
+
causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
|
| 476 |
+
});
|
| 477 |
+
});
|
| 478 |
+
}
|
| 479 |
+
|
| 480 |
+
/*
|
| 481 |
+
PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
|
| 482 |
+
m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
|
| 483 |
+
m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
|
| 484 |
+
m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
|
| 485 |
+
}
|
| 486 |
+
*/
|
causal-conv1d/causal_conv1d.h
ADDED
|
@@ -0,0 +1,81 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2024, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 8 |
+
|
| 9 |
+
struct ConvParamsBase {
|
| 10 |
+
using index_t = uint32_t;
|
| 11 |
+
|
| 12 |
+
int batch, dim, seqlen, width;
|
| 13 |
+
bool silu_activation;
|
| 14 |
+
|
| 15 |
+
index_t x_batch_stride;
|
| 16 |
+
index_t x_c_stride;
|
| 17 |
+
index_t x_l_stride;
|
| 18 |
+
index_t weight_c_stride;
|
| 19 |
+
index_t weight_width_stride;
|
| 20 |
+
index_t out_batch_stride;
|
| 21 |
+
index_t out_c_stride;
|
| 22 |
+
index_t out_l_stride;
|
| 23 |
+
|
| 24 |
+
int conv_state_len;
|
| 25 |
+
index_t conv_state_batch_stride;
|
| 26 |
+
index_t conv_state_c_stride;
|
| 27 |
+
index_t conv_state_l_stride;
|
| 28 |
+
|
| 29 |
+
// Common data pointers.
|
| 30 |
+
void *__restrict__ x_ptr;
|
| 31 |
+
void *__restrict__ weight_ptr;
|
| 32 |
+
void *__restrict__ bias_ptr;
|
| 33 |
+
void *__restrict__ out_ptr;
|
| 34 |
+
|
| 35 |
+
void *__restrict__ conv_state_ptr;
|
| 36 |
+
int32_t *__restrict__ cache_seqlens;
|
| 37 |
+
|
| 38 |
+
// Only used if the elements of the batch are gathered from a larger buffer,
|
| 39 |
+
// which may happen for continuous batching.
|
| 40 |
+
int32_t *__restrict__ conv_state_indices_ptr;
|
| 41 |
+
|
| 42 |
+
void *__restrict__ seq_idx_ptr;
|
| 43 |
+
|
| 44 |
+
// No __restrict__ since initial_states could be the same as final_states.
|
| 45 |
+
void * initial_states_ptr;
|
| 46 |
+
index_t initial_states_batch_stride;
|
| 47 |
+
index_t initial_states_l_stride;
|
| 48 |
+
index_t initial_states_c_stride;
|
| 49 |
+
|
| 50 |
+
void * final_states_ptr;
|
| 51 |
+
index_t final_states_batch_stride;
|
| 52 |
+
index_t final_states_l_stride;
|
| 53 |
+
index_t final_states_c_stride;
|
| 54 |
+
};
|
| 55 |
+
|
| 56 |
+
struct ConvParamsBwd: public ConvParamsBase {
|
| 57 |
+
index_t dx_batch_stride;
|
| 58 |
+
index_t dx_c_stride;
|
| 59 |
+
index_t dx_l_stride;
|
| 60 |
+
index_t dweight_c_stride;
|
| 61 |
+
index_t dweight_width_stride;
|
| 62 |
+
index_t dout_batch_stride;
|
| 63 |
+
index_t dout_c_stride;
|
| 64 |
+
index_t dout_l_stride;
|
| 65 |
+
|
| 66 |
+
// Common data pointers.
|
| 67 |
+
void *__restrict__ dx_ptr;
|
| 68 |
+
void *__restrict__ dweight_ptr;
|
| 69 |
+
void *__restrict__ dbias_ptr;
|
| 70 |
+
void *__restrict__ dout_ptr;
|
| 71 |
+
|
| 72 |
+
void * dinitial_states_ptr;
|
| 73 |
+
index_t dinitial_states_batch_stride;
|
| 74 |
+
index_t dinitial_states_l_stride;
|
| 75 |
+
index_t dinitial_states_c_stride;
|
| 76 |
+
|
| 77 |
+
void * dfinal_states_ptr;
|
| 78 |
+
index_t dfinal_states_batch_stride;
|
| 79 |
+
index_t dfinal_states_l_stride;
|
| 80 |
+
index_t dfinal_states_c_stride;
|
| 81 |
+
};
|
causal-conv1d/causal_conv1d_bwd.cu
ADDED
|
@@ -0,0 +1,627 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2024, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 8 |
+
|
| 9 |
+
#ifndef USE_ROCM
|
| 10 |
+
#include <cub/block/block_load.cuh>
|
| 11 |
+
#include <cub/block/block_store.cuh>
|
| 12 |
+
#include <cub/block/block_reduce.cuh>
|
| 13 |
+
#else
|
| 14 |
+
#include <hipcub/hipcub.hpp>
|
| 15 |
+
namespace cub = hipcub;
|
| 16 |
+
#endif
|
| 17 |
+
|
| 18 |
+
#include "causal_conv1d.h"
|
| 19 |
+
#include "causal_conv1d_common.h"
|
| 20 |
+
#include "static_switch.h"
|
| 21 |
+
|
| 22 |
+
template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
| 23 |
+
struct Causal_conv1d_bwd_kernel_traits {
|
| 24 |
+
using input_t = input_t_;
|
| 25 |
+
using weight_t = weight_t_;
|
| 26 |
+
static constexpr int kNThreads = kNThreads_;
|
| 27 |
+
static constexpr int kWidth = kWidth_;
|
| 28 |
+
static constexpr bool kSiluAct = kSiluAct_;
|
| 29 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 30 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 31 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
| 32 |
+
static_assert(kWidth <= kNElts);
|
| 33 |
+
// It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
|
| 34 |
+
// (since then we'd have 8 values of float, and each round we can exchange 4 floats).
|
| 35 |
+
static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
|
| 36 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
| 37 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 38 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 39 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
| 40 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 41 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
| 42 |
+
using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
|
| 43 |
+
static constexpr int kSmemIOSize = kIsVecLoad
|
| 44 |
+
? 0
|
| 45 |
+
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
| 46 |
+
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
|
| 47 |
+
static constexpr int kSmemSize = custom_max({kSmemExchangeSize,
|
| 48 |
+
int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
|
| 49 |
+
};
|
| 50 |
+
|
| 51 |
+
template<typename Ktraits>
|
| 52 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 53 |
+
void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
|
| 54 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 55 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 56 |
+
constexpr bool kSiluAct = Ktraits::kSiluAct;
|
| 57 |
+
static constexpr int kNElts = Ktraits::kNElts;
|
| 58 |
+
constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
|
| 59 |
+
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
| 60 |
+
using input_t = typename Ktraits::input_t;
|
| 61 |
+
using vec_t = typename Ktraits::vec_t;
|
| 62 |
+
using weight_t = typename Ktraits::weight_t;
|
| 63 |
+
|
| 64 |
+
// Shared memory.
|
| 65 |
+
extern __shared__ char smem_[];
|
| 66 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 67 |
+
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
| 68 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 69 |
+
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
| 70 |
+
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
| 71 |
+
vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
|
| 72 |
+
auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
|
| 73 |
+
|
| 74 |
+
const int tidx = threadIdx.x;
|
| 75 |
+
const int batch_id = blockIdx.x;
|
| 76 |
+
const int dim_id = blockIdx.y;
|
| 77 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 78 |
+
+ dim_id * params.x_c_stride;
|
| 79 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
|
| 80 |
+
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
| 81 |
+
+ dim_id * params.dout_c_stride;
|
| 82 |
+
input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
|
| 83 |
+
+ dim_id * params.dx_c_stride;
|
| 84 |
+
float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
|
| 85 |
+
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
|
| 86 |
+
|
| 87 |
+
// Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
|
| 88 |
+
if (tidx == 0) {
|
| 89 |
+
if constexpr (!kSiluAct) {
|
| 90 |
+
input_t zeros[kNElts] = {0};
|
| 91 |
+
smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
|
| 92 |
+
} else {
|
| 93 |
+
float zeros[kNElts] = {0};
|
| 94 |
+
#pragma unroll
|
| 95 |
+
for (int r = 0; r < kNExchangeRounds; ++r) {
|
| 96 |
+
smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
|
| 97 |
+
}
|
| 98 |
+
}
|
| 99 |
+
}
|
| 100 |
+
|
| 101 |
+
float weight_vals[kWidth];
|
| 102 |
+
#pragma unroll
|
| 103 |
+
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
|
| 104 |
+
|
| 105 |
+
float dweight_vals[kWidth] = {0};
|
| 106 |
+
float dbias_val = 0;
|
| 107 |
+
|
| 108 |
+
constexpr int kChunkSize = kNThreads * kNElts;
|
| 109 |
+
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
| 110 |
+
x += (n_chunks - 1) * kChunkSize;
|
| 111 |
+
dout += (n_chunks - 1) * kChunkSize;
|
| 112 |
+
dx += (n_chunks - 1) * kChunkSize;
|
| 113 |
+
for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
|
| 114 |
+
input_t x_vals_load[2 * kNElts] = {0};
|
| 115 |
+
input_t dout_vals_load[2 * kNElts] = {0};
|
| 116 |
+
if constexpr(kIsVecLoad) {
|
| 117 |
+
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 118 |
+
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 119 |
+
} else {
|
| 120 |
+
__syncthreads();
|
| 121 |
+
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
| 122 |
+
__syncthreads();
|
| 123 |
+
typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
|
| 124 |
+
}
|
| 125 |
+
float dout_vals[2 * kNElts], x_vals[2 * kNElts];
|
| 126 |
+
if constexpr (!kSiluAct) {
|
| 127 |
+
__syncthreads();
|
| 128 |
+
// Thread 0 don't write yet, so that thread kNThreads - 1 can read
|
| 129 |
+
// the first elements of the next chunk.
|
| 130 |
+
if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
|
| 131 |
+
__syncthreads();
|
| 132 |
+
reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
|
| 133 |
+
__syncthreads();
|
| 134 |
+
// Now thread 0 can write the first elements of the current chunk.
|
| 135 |
+
if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
|
| 136 |
+
#pragma unroll
|
| 137 |
+
for (int i = 0; i < 2 * kNElts; ++i) {
|
| 138 |
+
dout_vals[i] = float(dout_vals_load[i]);
|
| 139 |
+
x_vals[i] = float(x_vals_load[i]);
|
| 140 |
+
}
|
| 141 |
+
} else {
|
| 142 |
+
if (tidx == 0 && chunk > 0) {
|
| 143 |
+
if constexpr(kIsVecLoad) {
|
| 144 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
|
| 145 |
+
} else {
|
| 146 |
+
#pragma unroll
|
| 147 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 148 |
+
if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
|
| 149 |
+
}
|
| 150 |
+
}
|
| 151 |
+
}
|
| 152 |
+
__syncthreads();
|
| 153 |
+
smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
|
| 154 |
+
__syncthreads();
|
| 155 |
+
if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
|
| 156 |
+
#pragma unroll
|
| 157 |
+
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
| 158 |
+
// Recompute the output
|
| 159 |
+
#pragma unroll
|
| 160 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 161 |
+
float out_val = bias_val;
|
| 162 |
+
#pragma unroll
|
| 163 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 164 |
+
out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
| 165 |
+
}
|
| 166 |
+
float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
|
| 167 |
+
dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
|
| 168 |
+
* (1.0f + out_val * (1.0f - out_sigmoid_val));
|
| 169 |
+
}
|
| 170 |
+
// Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
|
| 171 |
+
// if input_t is 16 bits (since then we'd have 8 values of float)
|
| 172 |
+
__syncthreads();
|
| 173 |
+
// Thread 0 don't write yet, so that thread kNThreads - 1 can read
|
| 174 |
+
// the first elements of the next chunk.
|
| 175 |
+
if (tidx > 0) {
|
| 176 |
+
#pragma unroll
|
| 177 |
+
for (int r = 0; r < kNExchangeRounds; ++r) {
|
| 178 |
+
smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
|
| 179 |
+
}
|
| 180 |
+
}
|
| 181 |
+
__syncthreads();
|
| 182 |
+
#pragma unroll
|
| 183 |
+
for (int r = 0; r < kNExchangeRounds; ++r) {
|
| 184 |
+
reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
|
| 185 |
+
= smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
|
| 186 |
+
}
|
| 187 |
+
__syncthreads();
|
| 188 |
+
// Now thread 0 can write the first elements of the current chunk.
|
| 189 |
+
if (tidx == 0) {
|
| 190 |
+
#pragma unroll
|
| 191 |
+
for (int r = 0; r < kNExchangeRounds; ++r) {
|
| 192 |
+
smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
|
| 193 |
+
}
|
| 194 |
+
}
|
| 195 |
+
}
|
| 196 |
+
dout -= kChunkSize;
|
| 197 |
+
x -= kChunkSize;
|
| 198 |
+
|
| 199 |
+
#pragma unroll
|
| 200 |
+
for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
|
| 201 |
+
|
| 202 |
+
float dx_vals[kNElts] = {0};
|
| 203 |
+
#pragma unroll
|
| 204 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 205 |
+
#pragma unroll
|
| 206 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 207 |
+
dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
|
| 208 |
+
}
|
| 209 |
+
}
|
| 210 |
+
|
| 211 |
+
input_t dx_vals_store[kNElts];
|
| 212 |
+
#pragma unroll
|
| 213 |
+
for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
|
| 214 |
+
if constexpr(kIsVecLoad) {
|
| 215 |
+
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 216 |
+
} else {
|
| 217 |
+
typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
|
| 218 |
+
}
|
| 219 |
+
dx -= kChunkSize;
|
| 220 |
+
|
| 221 |
+
#pragma unroll
|
| 222 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 223 |
+
#pragma unroll
|
| 224 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 225 |
+
dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
|
| 226 |
+
}
|
| 227 |
+
}
|
| 228 |
+
}
|
| 229 |
+
|
| 230 |
+
#pragma unroll
|
| 231 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 232 |
+
__syncthreads();
|
| 233 |
+
dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
|
| 234 |
+
if (tidx == 0) {
|
| 235 |
+
atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
|
| 236 |
+
}
|
| 237 |
+
}
|
| 238 |
+
if (params.bias_ptr != nullptr) {
|
| 239 |
+
__syncthreads();
|
| 240 |
+
dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
|
| 241 |
+
if (tidx == 0) {
|
| 242 |
+
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
|
| 243 |
+
}
|
| 244 |
+
}
|
| 245 |
+
}
|
| 246 |
+
|
| 247 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 248 |
+
void causal_conv1d_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 249 |
+
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
| 250 |
+
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
| 251 |
+
BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
|
| 252 |
+
using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
|
| 253 |
+
constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 254 |
+
dim3 grid(params.batch, params.dim);
|
| 255 |
+
auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
|
| 256 |
+
|
| 257 |
+
if (kSmemSize >= 48 * 1024) {
|
| 258 |
+
#ifndef USE_ROCM
|
| 259 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 260 |
+
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 261 |
+
#else
|
| 262 |
+
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
| 263 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 264 |
+
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 265 |
+
std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
| 266 |
+
#endif
|
| 267 |
+
}
|
| 268 |
+
|
| 269 |
+
|
| 270 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 271 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 272 |
+
});
|
| 273 |
+
});
|
| 274 |
+
}
|
| 275 |
+
|
| 276 |
+
template<typename input_t, typename weight_t>
|
| 277 |
+
void causal_conv1d_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 278 |
+
if (params.width == 2) {
|
| 279 |
+
causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 280 |
+
} else if (params.width == 3) {
|
| 281 |
+
causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 282 |
+
} else if (params.width == 4) {
|
| 283 |
+
causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 284 |
+
}
|
| 285 |
+
}
|
| 286 |
+
|
| 287 |
+
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
| 288 |
+
struct Causal_conv1d_channellast_bwd_kernel_traits {
|
| 289 |
+
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
| 290 |
+
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
| 291 |
+
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
| 292 |
+
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
| 293 |
+
using input_t = input_t_;
|
| 294 |
+
using weight_t = weight_t_;
|
| 295 |
+
static constexpr bool kSiluAct = kSiluAct_;
|
| 296 |
+
static constexpr int kNThreads = kNThreads_;
|
| 297 |
+
static_assert(kNThreads % 32 == 0);
|
| 298 |
+
static constexpr int kNWarps = kNThreads / 32;
|
| 299 |
+
static constexpr int kWidth = kWidth_;
|
| 300 |
+
static constexpr int kChunkSizeL = kChunkSizeL_;
|
| 301 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 302 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 303 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
| 304 |
+
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
| 305 |
+
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
| 306 |
+
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
| 307 |
+
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
| 308 |
+
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
| 309 |
+
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
| 310 |
+
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
| 311 |
+
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
| 312 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
| 313 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 314 |
+
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 315 |
+
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 316 |
+
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 317 |
+
// sizeof(typename BlockStoreT::TempStorage)});
|
| 318 |
+
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
| 319 |
+
};
|
| 320 |
+
|
| 321 |
+
template<typename Ktraits, bool kHasSeqIdx, bool kHasDfinalStates>
|
| 322 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 323 |
+
void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
|
| 324 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 325 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 326 |
+
constexpr bool kSiluAct = Ktraits::kSiluAct;
|
| 327 |
+
constexpr int kNElts = Ktraits::kNElts;
|
| 328 |
+
constexpr int kNWarp = Ktraits::kNWarps;
|
| 329 |
+
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
| 330 |
+
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
| 331 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 332 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 333 |
+
using input_t = typename Ktraits::input_t;
|
| 334 |
+
using vec_t = typename Ktraits::vec_t;
|
| 335 |
+
using weight_t = typename Ktraits::weight_t;
|
| 336 |
+
|
| 337 |
+
// Shared memory.
|
| 338 |
+
__shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
|
| 339 |
+
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
|
| 340 |
+
|
| 341 |
+
const int batch_id = blockIdx.x;
|
| 342 |
+
const int chunk_l_id = blockIdx.y;
|
| 343 |
+
const int chunk_c_id = blockIdx.z;
|
| 344 |
+
const int tid = threadIdx.x;
|
| 345 |
+
const int l_idx = tid / kNThreadsPerC;
|
| 346 |
+
const int c_idx = tid % kNThreadsPerC;
|
| 347 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 348 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 349 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
| 350 |
+
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
| 351 |
+
input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
|
| 352 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 353 |
+
input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
|
| 354 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 355 |
+
float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
|
| 356 |
+
+ chunk_c_id * kChunkSizeC * params.dweight_c_stride;
|
| 357 |
+
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
| 358 |
+
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
| 359 |
+
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
| 360 |
+
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 361 |
+
input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
| 362 |
+
: reinterpret_cast<input_t *>(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 363 |
+
input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr
|
| 364 |
+
: reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC;
|
| 365 |
+
|
| 366 |
+
#pragma unroll
|
| 367 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 368 |
+
input_t dout_vals_load[kNElts] = {0};
|
| 369 |
+
input_t x_vals_load[kNElts] = {0};
|
| 370 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 371 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 372 |
+
reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
|
| 373 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
| 374 |
+
}
|
| 375 |
+
reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
|
| 376 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 377 |
+
}
|
| 378 |
+
// Load the elements from the previous chunk or next chunk that are needed for convolution.
|
| 379 |
+
if (l_idx < kWidth - 1) {
|
| 380 |
+
input_t dout_vals_load[kNElts] = {0};
|
| 381 |
+
input_t x_vals_load[kNElts] = {0};
|
| 382 |
+
if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
|
| 383 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 384 |
+
reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
|
| 385 |
+
}
|
| 386 |
+
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
| 387 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
| 388 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 389 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
| 390 |
+
} else if (initial_states != nullptr
|
| 391 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
| 392 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 393 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
| 394 |
+
}
|
| 395 |
+
reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
|
| 396 |
+
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 397 |
+
}
|
| 398 |
+
// Need to load (kWdith - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
|
| 399 |
+
if constexpr (kSiluAct) {
|
| 400 |
+
if (l_idx < kWidth - 1) {
|
| 401 |
+
input_t x_vals_load[kNElts] = {0};
|
| 402 |
+
if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
|
| 403 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 404 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
|
| 405 |
+
}
|
| 406 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 407 |
+
}
|
| 408 |
+
}
|
| 409 |
+
|
| 410 |
+
__syncthreads();
|
| 411 |
+
|
| 412 |
+
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
| 413 |
+
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
| 414 |
+
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
| 415 |
+
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
| 416 |
+
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
| 417 |
+
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
| 418 |
+
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
| 419 |
+
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
| 420 |
+
static_assert(kNThreadsPerRow <= 32);
|
| 421 |
+
|
| 422 |
+
const int row_idx = tid / kNThreadsPerRow;
|
| 423 |
+
const int col_idx = tid % kNThreadsPerRow;
|
| 424 |
+
|
| 425 |
+
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
| 426 |
+
float weight_vals[kWidth] = {0};
|
| 427 |
+
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 428 |
+
#pragma unroll
|
| 429 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 430 |
+
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
| 431 |
+
}
|
| 432 |
+
}
|
| 433 |
+
float dout_vals[kLPerThread + kWidth - 1];
|
| 434 |
+
float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
|
| 435 |
+
#pragma unroll
|
| 436 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
| 437 |
+
dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
|
| 438 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
| 439 |
+
}
|
| 440 |
+
|
| 441 |
+
int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1];
|
| 442 |
+
if constexpr (kHasSeqIdx) {
|
| 443 |
+
#pragma unroll
|
| 444 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
|
| 445 |
+
const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
|
| 446 |
+
seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
| 447 |
+
}
|
| 448 |
+
}
|
| 449 |
+
|
| 450 |
+
if constexpr (kSiluAct) { // Recompute the output
|
| 451 |
+
#pragma unroll
|
| 452 |
+
for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
|
| 453 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
| 454 |
+
}
|
| 455 |
+
#pragma unroll
|
| 456 |
+
for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
|
| 457 |
+
float out_val = bias_val;
|
| 458 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
| 459 |
+
#pragma unroll
|
| 460 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 461 |
+
if constexpr (!kHasSeqIdx) {
|
| 462 |
+
out_val += weight_vals[w] * x_vals[i + w];
|
| 463 |
+
} else {
|
| 464 |
+
out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
| 465 |
+
}
|
| 466 |
+
}
|
| 467 |
+
float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
|
| 468 |
+
dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
|
| 469 |
+
}
|
| 470 |
+
}
|
| 471 |
+
|
| 472 |
+
float dweight_vals[kWidth] = {0};
|
| 473 |
+
SumOp<float> sum_op;
|
| 474 |
+
#pragma unroll
|
| 475 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 476 |
+
#pragma unroll
|
| 477 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
| 478 |
+
if constexpr (!kHasSeqIdx) {
|
| 479 |
+
dweight_vals[w] += x_vals[i + w] * dout_vals[i];
|
| 480 |
+
} else {
|
| 481 |
+
dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f;
|
| 482 |
+
}
|
| 483 |
+
}
|
| 484 |
+
dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
|
| 485 |
+
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 486 |
+
atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
|
| 487 |
+
}
|
| 488 |
+
}
|
| 489 |
+
|
| 490 |
+
if (params.bias_ptr != nullptr) {
|
| 491 |
+
float dbias_val = 0.f;
|
| 492 |
+
for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
|
| 493 |
+
dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
|
| 494 |
+
if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 495 |
+
atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
|
| 496 |
+
}
|
| 497 |
+
}
|
| 498 |
+
|
| 499 |
+
float dx_vals[kLPerThread] = {0};
|
| 500 |
+
#pragma unroll
|
| 501 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
| 502 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
| 503 |
+
#pragma unroll
|
| 504 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 505 |
+
if constexpr (!kHasSeqIdx) {
|
| 506 |
+
dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w];
|
| 507 |
+
} else {
|
| 508 |
+
dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f;
|
| 509 |
+
}
|
| 510 |
+
}
|
| 511 |
+
// if (dfinal_states != nullptr) {
|
| 512 |
+
if constexpr (kHasDfinalStates) {
|
| 513 |
+
if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1
|
| 514 |
+
&& chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen
|
| 515 |
+
&& chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 516 |
+
dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
|
| 517 |
+
}
|
| 518 |
+
}
|
| 519 |
+
}
|
| 520 |
+
|
| 521 |
+
float dxinit_vals[kWidth - 1] = {0};
|
| 522 |
+
static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states
|
| 523 |
+
if (dinitial_states != nullptr && col_idx == 0) {
|
| 524 |
+
#pragma unroll
|
| 525 |
+
for (int i = 0; i < kWidth - 1; ++i) {
|
| 526 |
+
#pragma unroll
|
| 527 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 528 |
+
dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f;
|
| 529 |
+
}
|
| 530 |
+
// chunk_l_id must be 0 because dinitial_states != nullptr
|
| 531 |
+
// if (dfinal_states != nullptr) {
|
| 532 |
+
if constexpr (kHasDfinalStates) {
|
| 533 |
+
if (i >= params.seqlen) {
|
| 534 |
+
dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
|
| 535 |
+
}
|
| 536 |
+
}
|
| 537 |
+
}
|
| 538 |
+
}
|
| 539 |
+
|
| 540 |
+
__syncthreads();
|
| 541 |
+
#pragma unroll
|
| 542 |
+
for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
|
| 543 |
+
if (dinitial_states != nullptr && col_idx == 0) {
|
| 544 |
+
#pragma unroll
|
| 545 |
+
for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; }
|
| 546 |
+
}
|
| 547 |
+
__syncthreads();
|
| 548 |
+
|
| 549 |
+
#pragma unroll
|
| 550 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 551 |
+
input_t dx_vals_store[kNElts];
|
| 552 |
+
reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx];
|
| 553 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 554 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 555 |
+
*reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
|
| 556 |
+
}
|
| 557 |
+
}
|
| 558 |
+
if (dinitial_states != nullptr
|
| 559 |
+
&& l_idx < kWidth - 1
|
| 560 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 561 |
+
input_t dxinit_vals_store[kNElts];
|
| 562 |
+
reinterpret_cast<vec_t *>(dxinit_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx];
|
| 563 |
+
*reinterpret_cast<vec_t *>(dinitial_states) = reinterpret_cast<vec_t *>(dxinit_vals_store)[0];
|
| 564 |
+
}
|
| 565 |
+
|
| 566 |
+
}
|
| 567 |
+
|
| 568 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 569 |
+
void causal_conv1d_channellast_bwd_launch(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 570 |
+
BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
|
| 571 |
+
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
| 572 |
+
BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] {
|
| 573 |
+
BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] {
|
| 574 |
+
// kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger
|
| 575 |
+
static constexpr int kChunk = kChunkSizeL64 ? 64 : 128;
|
| 576 |
+
using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, kChunk, kSiluAct, true, input_t, weight_t>;
|
| 577 |
+
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 578 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 579 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 580 |
+
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
| 581 |
+
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
| 582 |
+
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
| 583 |
+
dim3 block(Ktraits::kNThreads);
|
| 584 |
+
auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits, kHasSeqIdx, kHasDfinalStates>;
|
| 585 |
+
// if (kSmemSize >= 48 * 1024) {
|
| 586 |
+
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 587 |
+
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 588 |
+
// }
|
| 589 |
+
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 590 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
| 591 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 592 |
+
});
|
| 593 |
+
});
|
| 594 |
+
});
|
| 595 |
+
});
|
| 596 |
+
}
|
| 597 |
+
|
| 598 |
+
template<typename input_t, typename weight_t>
|
| 599 |
+
void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd ¶ms, cudaStream_t stream) {
|
| 600 |
+
if (params.width == 2) {
|
| 601 |
+
causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 602 |
+
} else if (params.width == 3) {
|
| 603 |
+
causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 604 |
+
} else if (params.width == 4) {
|
| 605 |
+
causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 606 |
+
}
|
| 607 |
+
}
|
| 608 |
+
|
| 609 |
+
template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 610 |
+
template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 611 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 612 |
+
template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 613 |
+
template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 614 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 615 |
+
template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 616 |
+
template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 617 |
+
template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 618 |
+
|
| 619 |
+
template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 620 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 621 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 622 |
+
template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 623 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 624 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 625 |
+
template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 626 |
+
template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
| 627 |
+
template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd ¶ms, cudaStream_t stream);
|
causal-conv1d/causal_conv1d_common.h
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#pragma once
|
| 6 |
+
|
| 7 |
+
#ifndef USE_ROCM
|
| 8 |
+
#include <cuda_bf16.h>
|
| 9 |
+
|
| 10 |
+
template<typename T>
|
| 11 |
+
__device__ inline T shuffle_xor(T val, int offset) {
|
| 12 |
+
return __shfl_xor_sync(uint32_t(-1), val, offset);
|
| 13 |
+
}
|
| 14 |
+
|
| 15 |
+
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
| 16 |
+
{
|
| 17 |
+
return std::max(ilist);
|
| 18 |
+
}
|
| 19 |
+
|
| 20 |
+
template<typename T>
|
| 21 |
+
constexpr T constexpr_min(T a, T b) {
|
| 22 |
+
return std::min(a, b);
|
| 23 |
+
}
|
| 24 |
+
|
| 25 |
+
#else
|
| 26 |
+
#include <hip/hip_bf16.h>
|
| 27 |
+
|
| 28 |
+
template<typename T>
|
| 29 |
+
__device__ inline T shuffle_xor(T val, int offset) {
|
| 30 |
+
return __shfl_xor(val, offset);
|
| 31 |
+
}
|
| 32 |
+
constexpr size_t custom_max(std::initializer_list<size_t> ilist)
|
| 33 |
+
{
|
| 34 |
+
return *std::max_element(ilist.begin(), ilist.end());
|
| 35 |
+
}
|
| 36 |
+
|
| 37 |
+
template<typename T>
|
| 38 |
+
constexpr T constexpr_min(T a, T b) {
|
| 39 |
+
return a < b ? a : b;
|
| 40 |
+
}
|
| 41 |
+
#endif
|
| 42 |
+
#include <cuda_fp16.h>
|
| 43 |
+
|
| 44 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 45 |
+
|
| 46 |
+
template<int BYTES> struct BytesToType {};
|
| 47 |
+
|
| 48 |
+
template<> struct BytesToType<16> {
|
| 49 |
+
using Type = uint4;
|
| 50 |
+
static_assert(sizeof(Type) == 16);
|
| 51 |
+
};
|
| 52 |
+
|
| 53 |
+
template<> struct BytesToType<8> {
|
| 54 |
+
using Type = uint64_t;
|
| 55 |
+
static_assert(sizeof(Type) == 8);
|
| 56 |
+
};
|
| 57 |
+
|
| 58 |
+
template<> struct BytesToType<4> {
|
| 59 |
+
using Type = uint32_t;
|
| 60 |
+
static_assert(sizeof(Type) == 4);
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
template<> struct BytesToType<2> {
|
| 64 |
+
using Type = uint16_t;
|
| 65 |
+
static_assert(sizeof(Type) == 2);
|
| 66 |
+
};
|
| 67 |
+
|
| 68 |
+
template<> struct BytesToType<1> {
|
| 69 |
+
using Type = uint8_t;
|
| 70 |
+
static_assert(sizeof(Type) == 1);
|
| 71 |
+
};
|
| 72 |
+
|
| 73 |
+
////////////////////////////////////////////////////////////////////////////////////////////////////
|
| 74 |
+
|
| 75 |
+
template<typename T>
|
| 76 |
+
struct SumOp {
|
| 77 |
+
__device__ inline T operator()(T const & x, T const & y) { return x + y; }
|
| 78 |
+
};
|
| 79 |
+
|
| 80 |
+
template<int THREADS>
|
| 81 |
+
struct Allreduce {
|
| 82 |
+
static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
|
| 83 |
+
template<typename T, typename Operator>
|
| 84 |
+
static __device__ inline T run(T x, Operator &op) {
|
| 85 |
+
constexpr int OFFSET = THREADS / 2;
|
| 86 |
+
x = op(x, shuffle_xor(x, OFFSET));
|
| 87 |
+
return Allreduce<OFFSET>::run(x, op);
|
| 88 |
+
}
|
| 89 |
+
};
|
| 90 |
+
|
| 91 |
+
template<>
|
| 92 |
+
struct Allreduce<2> {
|
| 93 |
+
template<typename T, typename Operator>
|
| 94 |
+
static __device__ inline T run(T x, Operator &op) {
|
| 95 |
+
x = op(x, shuffle_xor(x, 1));
|
| 96 |
+
return x;
|
| 97 |
+
}
|
| 98 |
+
};
|
causal-conv1d/causal_conv1d_fwd.cu
ADDED
|
@@ -0,0 +1,399 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2024, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 8 |
+
|
| 9 |
+
#ifndef USE_ROCM
|
| 10 |
+
#include <cub/block/block_load.cuh>
|
| 11 |
+
#include <cub/block/block_store.cuh>
|
| 12 |
+
#else
|
| 13 |
+
#include <hipcub/hipcub.hpp>
|
| 14 |
+
namespace cub = hipcub;
|
| 15 |
+
#endif
|
| 16 |
+
|
| 17 |
+
#include "causal_conv1d.h"
|
| 18 |
+
#include "causal_conv1d_common.h"
|
| 19 |
+
#include "static_switch.h"
|
| 20 |
+
|
| 21 |
+
template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
| 22 |
+
struct Causal_conv1d_fwd_kernel_traits {
|
| 23 |
+
using input_t = input_t_;
|
| 24 |
+
using weight_t = weight_t_;
|
| 25 |
+
static constexpr int kNThreads = kNThreads_;
|
| 26 |
+
static constexpr int kWidth = kWidth_;
|
| 27 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 28 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 29 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
| 30 |
+
static_assert(kWidth <= kNElts);
|
| 31 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
| 32 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 33 |
+
using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 34 |
+
using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
|
| 35 |
+
using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 36 |
+
using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
|
| 37 |
+
static constexpr int kSmemIOSize = kIsVecLoad
|
| 38 |
+
? 0
|
| 39 |
+
: custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
|
| 40 |
+
static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
|
| 41 |
+
static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
|
| 42 |
+
};
|
| 43 |
+
|
| 44 |
+
template<typename Ktraits>
|
| 45 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 46 |
+
void causal_conv1d_fwd_kernel(ConvParamsBase params) {
|
| 47 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 48 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 49 |
+
constexpr int kNElts = Ktraits::kNElts;
|
| 50 |
+
static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
|
| 51 |
+
using input_t = typename Ktraits::input_t;
|
| 52 |
+
using vec_t = typename Ktraits::vec_t;
|
| 53 |
+
using weight_t = typename Ktraits::weight_t;
|
| 54 |
+
|
| 55 |
+
// Shared memory.
|
| 56 |
+
extern __shared__ char smem_[];
|
| 57 |
+
auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
|
| 58 |
+
auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
|
| 59 |
+
auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
|
| 60 |
+
auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
|
| 61 |
+
vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
|
| 62 |
+
|
| 63 |
+
const int tidx = threadIdx.x;
|
| 64 |
+
const int batch_id = blockIdx.x;
|
| 65 |
+
const int channel_id = blockIdx.y;
|
| 66 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 67 |
+
+ channel_id * params.x_c_stride;
|
| 68 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
| 69 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
| 70 |
+
+ channel_id * params.out_c_stride;
|
| 71 |
+
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
| 72 |
+
|
| 73 |
+
// Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
|
| 74 |
+
if (tidx == 0) {
|
| 75 |
+
input_t zeros[kNElts] = {0};
|
| 76 |
+
smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
float weight_vals[kWidth];
|
| 80 |
+
#pragma unroll
|
| 81 |
+
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
| 82 |
+
|
| 83 |
+
constexpr int kChunkSize = kNThreads * kNElts;
|
| 84 |
+
const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
|
| 85 |
+
for (int chunk = 0; chunk < n_chunks; ++chunk) {
|
| 86 |
+
input_t x_vals_load[2 * kNElts] = {0};
|
| 87 |
+
if constexpr(kIsVecLoad) {
|
| 88 |
+
typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 89 |
+
} else {
|
| 90 |
+
__syncthreads();
|
| 91 |
+
typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
|
| 92 |
+
}
|
| 93 |
+
x += kChunkSize;
|
| 94 |
+
__syncthreads();
|
| 95 |
+
// Thread kNThreads - 1 don't write yet, so that thread 0 can read
|
| 96 |
+
// the last elements of the previous chunk.
|
| 97 |
+
if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
| 98 |
+
__syncthreads();
|
| 99 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
|
| 100 |
+
__syncthreads();
|
| 101 |
+
// Now thread kNThreads - 1 can write the last elements of the current chunk.
|
| 102 |
+
if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
|
| 103 |
+
|
| 104 |
+
float x_vals[2 * kNElts];
|
| 105 |
+
#pragma unroll
|
| 106 |
+
for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
|
| 107 |
+
|
| 108 |
+
float out_vals[kNElts];
|
| 109 |
+
#pragma unroll
|
| 110 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 111 |
+
out_vals[i] = bias_val;
|
| 112 |
+
#pragma unroll
|
| 113 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 114 |
+
out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
|
| 115 |
+
}
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
if (params.silu_activation) {
|
| 119 |
+
#pragma unroll
|
| 120 |
+
for (int i = 0; i < kNElts; ++i) {
|
| 121 |
+
out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
|
| 122 |
+
}
|
| 123 |
+
}
|
| 124 |
+
|
| 125 |
+
input_t out_vals_store[kNElts];
|
| 126 |
+
#pragma unroll
|
| 127 |
+
for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
|
| 128 |
+
if constexpr(kIsVecLoad) {
|
| 129 |
+
typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
|
| 130 |
+
} else {
|
| 131 |
+
typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
|
| 132 |
+
}
|
| 133 |
+
out += kChunkSize;
|
| 134 |
+
}
|
| 135 |
+
}
|
| 136 |
+
|
| 137 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 138 |
+
void causal_conv1d_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 139 |
+
static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
|
| 140 |
+
BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
|
| 141 |
+
using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
|
| 142 |
+
constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 143 |
+
dim3 grid(params.batch, params.dim);
|
| 144 |
+
|
| 145 |
+
auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
|
| 146 |
+
|
| 147 |
+
if (kSmemSize >= 48 * 1024) {
|
| 148 |
+
#ifndef USE_ROCM
|
| 149 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 150 |
+
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 151 |
+
#else
|
| 152 |
+
// There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
|
| 153 |
+
C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 154 |
+
(void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 155 |
+
std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
|
| 156 |
+
#endif
|
| 157 |
+
}
|
| 158 |
+
kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 159 |
+
|
| 160 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 161 |
+
});
|
| 162 |
+
}
|
| 163 |
+
|
| 164 |
+
template<typename input_t, typename weight_t>
|
| 165 |
+
void causal_conv1d_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 166 |
+
if (params.width == 2) {
|
| 167 |
+
causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 168 |
+
} else if (params.width == 3) {
|
| 169 |
+
causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 170 |
+
} else if (params.width == 4) {
|
| 171 |
+
causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 172 |
+
}
|
| 173 |
+
}
|
| 174 |
+
|
| 175 |
+
template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
|
| 176 |
+
struct Causal_conv1d_channellast_fwd_kernel_traits {
|
| 177 |
+
// The cache line is 128 bytes, and we try to read 16 bytes per thread.
|
| 178 |
+
// So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
|
| 179 |
+
// That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
|
| 180 |
+
// threads). Each each load is 16 x 32|64 elements in the L x C dimensions.
|
| 181 |
+
using input_t = input_t_;
|
| 182 |
+
using weight_t = weight_t_;
|
| 183 |
+
static constexpr int kNThreads = kNThreads_;
|
| 184 |
+
static_assert(kNThreads % 32 == 0);
|
| 185 |
+
static constexpr int kNWarps = kNThreads / 32;
|
| 186 |
+
static constexpr int kWidth = kWidth_;
|
| 187 |
+
static constexpr int kChunkSizeL = kChunkSizeL_;
|
| 188 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 189 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 190 |
+
static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
|
| 191 |
+
static constexpr int kNEltsPerRow = 128 / kNBytes;
|
| 192 |
+
static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
|
| 193 |
+
static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
|
| 194 |
+
static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
|
| 195 |
+
static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
|
| 196 |
+
static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
|
| 197 |
+
static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
|
| 198 |
+
static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
|
| 199 |
+
static constexpr bool kIsVecLoad = kIsVecLoad_;
|
| 200 |
+
using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
|
| 201 |
+
// using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
|
| 202 |
+
// using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
|
| 203 |
+
// static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
|
| 204 |
+
// sizeof(typename BlockStoreT::TempStorage)});
|
| 205 |
+
// static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
|
| 206 |
+
};
|
| 207 |
+
|
| 208 |
+
template<typename Ktraits, bool kHasSeqIdx>
|
| 209 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 210 |
+
void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
|
| 211 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 212 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 213 |
+
constexpr int kNElts = Ktraits::kNElts;
|
| 214 |
+
constexpr int kNWarp = Ktraits::kNWarps;
|
| 215 |
+
constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
|
| 216 |
+
constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
|
| 217 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 218 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 219 |
+
using input_t = typename Ktraits::input_t;
|
| 220 |
+
using vec_t = typename Ktraits::vec_t;
|
| 221 |
+
using weight_t = typename Ktraits::weight_t;
|
| 222 |
+
|
| 223 |
+
// Shared memory.
|
| 224 |
+
__shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
|
| 225 |
+
|
| 226 |
+
const int batch_id = blockIdx.x;
|
| 227 |
+
const int chunk_l_id = blockIdx.y;
|
| 228 |
+
const int chunk_c_id = blockIdx.z;
|
| 229 |
+
const int tid = threadIdx.x;
|
| 230 |
+
const int l_idx = tid / kNThreadsPerC;
|
| 231 |
+
const int c_idx = tid % kNThreadsPerC;
|
| 232 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 233 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 234 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
|
| 235 |
+
+ chunk_c_id * kChunkSizeC * params.weight_c_stride;
|
| 236 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
| 237 |
+
+ (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 238 |
+
int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
|
| 239 |
+
+ batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
|
| 240 |
+
input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
|
| 241 |
+
: reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 242 |
+
// The last L-chunk will also have enough info to write to final states, since it also contain a few x values
|
| 243 |
+
// from the previous L-chunk.
|
| 244 |
+
input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
|
| 245 |
+
: reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
|
| 246 |
+
|
| 247 |
+
#pragma unroll
|
| 248 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 249 |
+
input_t x_vals_load[kNElts] = {0};
|
| 250 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 251 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 252 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
|
| 253 |
+
}
|
| 254 |
+
reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 255 |
+
}
|
| 256 |
+
// Load the elements from the previous chunk that are needed for convolution.
|
| 257 |
+
if (l_idx < kWidth - 1) {
|
| 258 |
+
input_t x_vals_load[kNElts] = {0};
|
| 259 |
+
if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
|
| 260 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
|
| 261 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 262 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
|
| 263 |
+
} else if (initial_states != nullptr
|
| 264 |
+
&& chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
|
| 265 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 266 |
+
reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
|
| 267 |
+
}
|
| 268 |
+
reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
|
| 269 |
+
}
|
| 270 |
+
|
| 271 |
+
__syncthreads();
|
| 272 |
+
|
| 273 |
+
if (final_states != nullptr
|
| 274 |
+
&& l_idx < kWidth - 1
|
| 275 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 276 |
+
// x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
|
| 277 |
+
// So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
|
| 278 |
+
*reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
|
| 279 |
+
}
|
| 280 |
+
|
| 281 |
+
constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
|
| 282 |
+
static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
|
| 283 |
+
constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
|
| 284 |
+
static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
|
| 285 |
+
// kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
|
| 286 |
+
static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
|
| 287 |
+
static_assert((kLPerThread & (kLPerThread - 1)) == 0);
|
| 288 |
+
static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
|
| 289 |
+
static_assert(kNThreadsPerRow <= 32);
|
| 290 |
+
|
| 291 |
+
const int row_idx = tid / kNThreadsPerRow;
|
| 292 |
+
const int col_idx = tid % kNThreadsPerRow;
|
| 293 |
+
|
| 294 |
+
float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
|
| 295 |
+
float weight_vals[kWidth] = {0};
|
| 296 |
+
if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
|
| 297 |
+
#pragma unroll
|
| 298 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 299 |
+
weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
|
| 300 |
+
}
|
| 301 |
+
}
|
| 302 |
+
float x_vals[kWidth - 1 + kLPerThread];
|
| 303 |
+
#pragma unroll
|
| 304 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
| 305 |
+
x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
|
| 306 |
+
}
|
| 307 |
+
int seq_idx_thread[kWidth - 1 + kLPerThread];
|
| 308 |
+
if constexpr (kHasSeqIdx) {
|
| 309 |
+
#pragma unroll
|
| 310 |
+
for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
|
| 311 |
+
seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
|
| 312 |
+
}
|
| 313 |
+
}
|
| 314 |
+
|
| 315 |
+
float out_vals[kLPerThread];
|
| 316 |
+
#pragma unroll
|
| 317 |
+
for (int i = 0; i < kLPerThread; ++i) {
|
| 318 |
+
out_vals[i] = bias_val;
|
| 319 |
+
const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
|
| 320 |
+
#pragma unroll
|
| 321 |
+
for (int w = 0; w < kWidth; ++w) {
|
| 322 |
+
if constexpr (!kHasSeqIdx) {
|
| 323 |
+
out_vals[i] += weight_vals[w] * x_vals[i + w];
|
| 324 |
+
} else {
|
| 325 |
+
out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
|
| 326 |
+
}
|
| 327 |
+
}
|
| 328 |
+
if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
|
| 329 |
+
}
|
| 330 |
+
|
| 331 |
+
__syncthreads();
|
| 332 |
+
#pragma unroll
|
| 333 |
+
for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
|
| 334 |
+
__syncthreads();
|
| 335 |
+
|
| 336 |
+
#pragma unroll
|
| 337 |
+
for (int l = 0; l < Ktraits::kNLoads; ++l) {
|
| 338 |
+
input_t out_vals_store[kNElts];
|
| 339 |
+
reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
|
| 340 |
+
if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
|
| 341 |
+
&& chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
|
| 342 |
+
*reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
|
| 343 |
+
}
|
| 344 |
+
}
|
| 345 |
+
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 349 |
+
void causal_conv1d_channellast_fwd_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 350 |
+
BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
|
| 351 |
+
using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
|
| 352 |
+
// constexpr int kSmemSize = Ktraits::kSmemSize;
|
| 353 |
+
constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
|
| 354 |
+
constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
|
| 355 |
+
const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
|
| 356 |
+
const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
|
| 357 |
+
dim3 grid(params.batch, n_chunks_L, n_chunks_C);
|
| 358 |
+
dim3 block(Ktraits::kNThreads);
|
| 359 |
+
auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
|
| 360 |
+
// if (kSmemSize >= 48 * 1024) {
|
| 361 |
+
// C10_CUDA_CHECK(cudaFuncSetAttribute(
|
| 362 |
+
// kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
|
| 363 |
+
// }
|
| 364 |
+
// kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
|
| 365 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
| 366 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 367 |
+
});
|
| 368 |
+
}
|
| 369 |
+
|
| 370 |
+
template<typename input_t, typename weight_t>
|
| 371 |
+
void causal_conv1d_channellast_fwd_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 372 |
+
if (params.width == 2) {
|
| 373 |
+
causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
|
| 374 |
+
} else if (params.width == 3) {
|
| 375 |
+
causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
|
| 376 |
+
} else if (params.width == 4) {
|
| 377 |
+
causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
|
| 378 |
+
}
|
| 379 |
+
}
|
| 380 |
+
|
| 381 |
+
template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 382 |
+
template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 383 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 384 |
+
template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 385 |
+
template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 386 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 387 |
+
template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 388 |
+
template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 389 |
+
template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 390 |
+
|
| 391 |
+
template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 392 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 393 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 394 |
+
template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 395 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 396 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 397 |
+
template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 398 |
+
template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 399 |
+
template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
causal-conv1d/causal_conv1d_update.cu
ADDED
|
@@ -0,0 +1,137 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
/******************************************************************************
|
| 2 |
+
* Copyright (c) 2023, Tri Dao.
|
| 3 |
+
******************************************************************************/
|
| 4 |
+
|
| 5 |
+
#include <c10/util/BFloat16.h>
|
| 6 |
+
#include <c10/util/Half.h>
|
| 7 |
+
#include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
|
| 8 |
+
|
| 9 |
+
#include "causal_conv1d.h"
|
| 10 |
+
#include "causal_conv1d_common.h"
|
| 11 |
+
#include "static_switch.h"
|
| 12 |
+
|
| 13 |
+
template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
|
| 14 |
+
struct Causal_conv1d_update_kernel_traits {
|
| 15 |
+
using input_t = input_t_;
|
| 16 |
+
using weight_t = weight_t_;
|
| 17 |
+
static constexpr int kNThreads = kNThreads_;
|
| 18 |
+
static constexpr int kWidth = kWidth_;
|
| 19 |
+
static constexpr int kNBytes = sizeof(input_t);
|
| 20 |
+
static_assert(kNBytes == 2 || kNBytes == 4);
|
| 21 |
+
};
|
| 22 |
+
|
| 23 |
+
template<typename Ktraits, bool kIsCircularBuffer>
|
| 24 |
+
__global__ __launch_bounds__(Ktraits::kNThreads)
|
| 25 |
+
void causal_conv1d_update_kernel(ConvParamsBase params) {
|
| 26 |
+
constexpr int kWidth = Ktraits::kWidth;
|
| 27 |
+
constexpr int kNThreads = Ktraits::kNThreads;
|
| 28 |
+
using input_t = typename Ktraits::input_t;
|
| 29 |
+
using weight_t = typename Ktraits::weight_t;
|
| 30 |
+
|
| 31 |
+
const int tidx = threadIdx.x;
|
| 32 |
+
const int batch_id = blockIdx.x;
|
| 33 |
+
const int channel_id = blockIdx.y * kNThreads + tidx;
|
| 34 |
+
if (channel_id >= params.dim) return;
|
| 35 |
+
|
| 36 |
+
input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
|
| 37 |
+
+ channel_id * params.x_c_stride;
|
| 38 |
+
|
| 39 |
+
// If params.conv_state_batch_indices is set, then the conv state is gathered from the conv state tensor
|
| 40 |
+
// along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
|
| 41 |
+
const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
|
| 42 |
+
? batch_id
|
| 43 |
+
: params.conv_state_indices_ptr[batch_id];
|
| 44 |
+
input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
|
| 45 |
+
+ conv_state_batch_coord * params.conv_state_batch_stride
|
| 46 |
+
+ channel_id * params.conv_state_c_stride;
|
| 47 |
+
weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
|
| 48 |
+
input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
|
| 49 |
+
+ channel_id * params.out_c_stride;
|
| 50 |
+
float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
|
| 51 |
+
|
| 52 |
+
int state_len = params.conv_state_len;
|
| 53 |
+
int advance_len = params.seqlen;
|
| 54 |
+
int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
|
| 55 |
+
int update_idx = cache_seqlen - (kWidth - 1);
|
| 56 |
+
update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
|
| 57 |
+
|
| 58 |
+
float weight_vals[kWidth] = {0};
|
| 59 |
+
#pragma unroll
|
| 60 |
+
for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
|
| 61 |
+
|
| 62 |
+
float x_vals[kWidth] = {0};
|
| 63 |
+
if constexpr (!kIsCircularBuffer) {
|
| 64 |
+
#pragma unroll 2
|
| 65 |
+
for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
|
| 66 |
+
conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
|
| 67 |
+
}
|
| 68 |
+
#pragma unroll
|
| 69 |
+
for (int i = 0; i < kWidth - 1; ++i) {
|
| 70 |
+
input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
|
| 71 |
+
if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
|
| 72 |
+
conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
|
| 73 |
+
}
|
| 74 |
+
x_vals[i] = float(state_val);
|
| 75 |
+
}
|
| 76 |
+
} else {
|
| 77 |
+
#pragma unroll
|
| 78 |
+
for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
|
| 79 |
+
input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
|
| 80 |
+
x_vals[i] = float(state_val);
|
| 81 |
+
}
|
| 82 |
+
}
|
| 83 |
+
#pragma unroll 2
|
| 84 |
+
for (int i = 0; i < params.seqlen; ++i) {
|
| 85 |
+
input_t x_val = x[i * params.x_l_stride];
|
| 86 |
+
if constexpr (!kIsCircularBuffer) {
|
| 87 |
+
if (i < advance_len && state_len - advance_len + i >= 0) {
|
| 88 |
+
conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
|
| 89 |
+
}
|
| 90 |
+
} else {
|
| 91 |
+
conv_state[update_idx * params.conv_state_l_stride] = x_val;
|
| 92 |
+
++update_idx;
|
| 93 |
+
update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
|
| 94 |
+
}
|
| 95 |
+
x_vals[kWidth - 1] = float(x_val);
|
| 96 |
+
float out_val = bias_val;
|
| 97 |
+
#pragma unroll
|
| 98 |
+
for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
|
| 99 |
+
if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
|
| 100 |
+
out[i * params.out_l_stride] = input_t(out_val);
|
| 101 |
+
// Shift the input buffer by 1
|
| 102 |
+
#pragma unroll
|
| 103 |
+
for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
|
| 104 |
+
}
|
| 105 |
+
}
|
| 106 |
+
|
| 107 |
+
template<int kNThreads, int kWidth, typename input_t, typename weight_t>
|
| 108 |
+
void causal_conv1d_update_launch(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 109 |
+
using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
|
| 110 |
+
dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
|
| 111 |
+
auto kernel = params.cache_seqlens == nullptr
|
| 112 |
+
? &causal_conv1d_update_kernel<Ktraits, false>
|
| 113 |
+
: &causal_conv1d_update_kernel<Ktraits, true>;
|
| 114 |
+
kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
|
| 115 |
+
C10_CUDA_KERNEL_LAUNCH_CHECK();
|
| 116 |
+
}
|
| 117 |
+
|
| 118 |
+
template<typename input_t, typename weight_t>
|
| 119 |
+
void causal_conv1d_update_cuda(ConvParamsBase ¶ms, cudaStream_t stream) {
|
| 120 |
+
if (params.width == 2) {
|
| 121 |
+
causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
|
| 122 |
+
} else if (params.width == 3) {
|
| 123 |
+
causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
|
| 124 |
+
} else if (params.width == 4) {
|
| 125 |
+
causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
|
| 126 |
+
}
|
| 127 |
+
}
|
| 128 |
+
|
| 129 |
+
template void causal_conv1d_update_cuda<float, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 130 |
+
template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 131 |
+
template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 132 |
+
template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 133 |
+
template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 134 |
+
template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 135 |
+
template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 136 |
+
template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
| 137 |
+
template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase ¶ms, cudaStream_t stream);
|
causal-conv1d/static_switch.h
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
// Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
|
| 2 |
+
// and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
|
| 3 |
+
|
| 4 |
+
#pragma once
|
| 5 |
+
|
| 6 |
+
/// @param COND - a boolean expression to switch by
|
| 7 |
+
/// @param CONST_NAME - a name given for the constexpr bool variable.
|
| 8 |
+
/// @param ... - code to execute for true and false
|
| 9 |
+
///
|
| 10 |
+
/// Usage:
|
| 11 |
+
/// ```
|
| 12 |
+
/// BOOL_SWITCH(flag, BoolConst, [&] {
|
| 13 |
+
/// some_function<BoolConst>(...);
|
| 14 |
+
/// });
|
| 15 |
+
/// ```
|
| 16 |
+
#define BOOL_SWITCH(COND, CONST_NAME, ...) \
|
| 17 |
+
[&] { \
|
| 18 |
+
if (COND) { \
|
| 19 |
+
static constexpr bool CONST_NAME = true; \
|
| 20 |
+
return __VA_ARGS__(); \
|
| 21 |
+
} else { \
|
| 22 |
+
static constexpr bool CONST_NAME = false; \
|
| 23 |
+
return __VA_ARGS__(); \
|
| 24 |
+
} \
|
| 25 |
+
}()
|
flake.lock
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"nodes": {
|
| 3 |
+
"flake-compat": {
|
| 4 |
+
"locked": {
|
| 5 |
+
"lastModified": 1747046372,
|
| 6 |
+
"narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
|
| 7 |
+
"owner": "edolstra",
|
| 8 |
+
"repo": "flake-compat",
|
| 9 |
+
"rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
|
| 10 |
+
"type": "github"
|
| 11 |
+
},
|
| 12 |
+
"original": {
|
| 13 |
+
"owner": "edolstra",
|
| 14 |
+
"repo": "flake-compat",
|
| 15 |
+
"type": "github"
|
| 16 |
+
}
|
| 17 |
+
},
|
| 18 |
+
"flake-compat_2": {
|
| 19 |
+
"locked": {
|
| 20 |
+
"lastModified": 1733328505,
|
| 21 |
+
"narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
|
| 22 |
+
"owner": "edolstra",
|
| 23 |
+
"repo": "flake-compat",
|
| 24 |
+
"rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
|
| 25 |
+
"type": "github"
|
| 26 |
+
},
|
| 27 |
+
"original": {
|
| 28 |
+
"owner": "edolstra",
|
| 29 |
+
"repo": "flake-compat",
|
| 30 |
+
"type": "github"
|
| 31 |
+
}
|
| 32 |
+
},
|
| 33 |
+
"flake-utils": {
|
| 34 |
+
"inputs": {
|
| 35 |
+
"systems": "systems"
|
| 36 |
+
},
|
| 37 |
+
"locked": {
|
| 38 |
+
"lastModified": 1731533236,
|
| 39 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 40 |
+
"owner": "numtide",
|
| 41 |
+
"repo": "flake-utils",
|
| 42 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 43 |
+
"type": "github"
|
| 44 |
+
},
|
| 45 |
+
"original": {
|
| 46 |
+
"owner": "numtide",
|
| 47 |
+
"repo": "flake-utils",
|
| 48 |
+
"type": "github"
|
| 49 |
+
}
|
| 50 |
+
},
|
| 51 |
+
"flake-utils_2": {
|
| 52 |
+
"inputs": {
|
| 53 |
+
"systems": "systems_2"
|
| 54 |
+
},
|
| 55 |
+
"locked": {
|
| 56 |
+
"lastModified": 1731533236,
|
| 57 |
+
"narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
|
| 58 |
+
"owner": "numtide",
|
| 59 |
+
"repo": "flake-utils",
|
| 60 |
+
"rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
|
| 61 |
+
"type": "github"
|
| 62 |
+
},
|
| 63 |
+
"original": {
|
| 64 |
+
"owner": "numtide",
|
| 65 |
+
"repo": "flake-utils",
|
| 66 |
+
"type": "github"
|
| 67 |
+
}
|
| 68 |
+
},
|
| 69 |
+
"hf-nix": {
|
| 70 |
+
"inputs": {
|
| 71 |
+
"flake-compat": "flake-compat_2",
|
| 72 |
+
"flake-utils": "flake-utils_2",
|
| 73 |
+
"nixpkgs": "nixpkgs"
|
| 74 |
+
},
|
| 75 |
+
"locked": {
|
| 76 |
+
"lastModified": 1754038838,
|
| 77 |
+
"narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
|
| 78 |
+
"owner": "huggingface",
|
| 79 |
+
"repo": "hf-nix",
|
| 80 |
+
"rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
|
| 81 |
+
"type": "github"
|
| 82 |
+
},
|
| 83 |
+
"original": {
|
| 84 |
+
"owner": "huggingface",
|
| 85 |
+
"repo": "hf-nix",
|
| 86 |
+
"type": "github"
|
| 87 |
+
}
|
| 88 |
+
},
|
| 89 |
+
"kernel-builder": {
|
| 90 |
+
"inputs": {
|
| 91 |
+
"flake-compat": "flake-compat",
|
| 92 |
+
"flake-utils": "flake-utils",
|
| 93 |
+
"hf-nix": "hf-nix",
|
| 94 |
+
"nixpkgs": [
|
| 95 |
+
"kernel-builder",
|
| 96 |
+
"hf-nix",
|
| 97 |
+
"nixpkgs"
|
| 98 |
+
]
|
| 99 |
+
},
|
| 100 |
+
"locked": {
|
| 101 |
+
"lastModified": 1755181472,
|
| 102 |
+
"narHash": "sha256-xOXjhehC5xi/XB4fXZ5c0L2sSyDjJQdlH7/BcdHLBaM=",
|
| 103 |
+
"owner": "huggingface",
|
| 104 |
+
"repo": "kernel-builder",
|
| 105 |
+
"rev": "85da46f660c1c43b40771c3df3b223bb3fa39bec",
|
| 106 |
+
"type": "github"
|
| 107 |
+
},
|
| 108 |
+
"original": {
|
| 109 |
+
"owner": "huggingface",
|
| 110 |
+
"repo": "kernel-builder",
|
| 111 |
+
"type": "github"
|
| 112 |
+
}
|
| 113 |
+
},
|
| 114 |
+
"nixpkgs": {
|
| 115 |
+
"locked": {
|
| 116 |
+
"lastModified": 1752785354,
|
| 117 |
+
"narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
|
| 118 |
+
"owner": "nixos",
|
| 119 |
+
"repo": "nixpkgs",
|
| 120 |
+
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
| 121 |
+
"type": "github"
|
| 122 |
+
},
|
| 123 |
+
"original": {
|
| 124 |
+
"owner": "nixos",
|
| 125 |
+
"repo": "nixpkgs",
|
| 126 |
+
"rev": "d38025438a6ee456758dc03188ca6873a415463b",
|
| 127 |
+
"type": "github"
|
| 128 |
+
}
|
| 129 |
+
},
|
| 130 |
+
"root": {
|
| 131 |
+
"inputs": {
|
| 132 |
+
"kernel-builder": "kernel-builder"
|
| 133 |
+
}
|
| 134 |
+
},
|
| 135 |
+
"systems": {
|
| 136 |
+
"locked": {
|
| 137 |
+
"lastModified": 1681028828,
|
| 138 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 139 |
+
"owner": "nix-systems",
|
| 140 |
+
"repo": "default",
|
| 141 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 142 |
+
"type": "github"
|
| 143 |
+
},
|
| 144 |
+
"original": {
|
| 145 |
+
"owner": "nix-systems",
|
| 146 |
+
"repo": "default",
|
| 147 |
+
"type": "github"
|
| 148 |
+
}
|
| 149 |
+
},
|
| 150 |
+
"systems_2": {
|
| 151 |
+
"locked": {
|
| 152 |
+
"lastModified": 1681028828,
|
| 153 |
+
"narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
|
| 154 |
+
"owner": "nix-systems",
|
| 155 |
+
"repo": "default",
|
| 156 |
+
"rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
|
| 157 |
+
"type": "github"
|
| 158 |
+
},
|
| 159 |
+
"original": {
|
| 160 |
+
"owner": "nix-systems",
|
| 161 |
+
"repo": "default",
|
| 162 |
+
"type": "github"
|
| 163 |
+
}
|
| 164 |
+
}
|
| 165 |
+
},
|
| 166 |
+
"root": "root",
|
| 167 |
+
"version": 7
|
| 168 |
+
}
|
flake.nix
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
description = "Flake for attention kernels";
|
| 3 |
+
|
| 4 |
+
inputs = {
|
| 5 |
+
kernel-builder.url = "github:huggingface/kernel-builder";
|
| 6 |
+
};
|
| 7 |
+
|
| 8 |
+
outputs =
|
| 9 |
+
{
|
| 10 |
+
self,
|
| 11 |
+
kernel-builder,
|
| 12 |
+
}:
|
| 13 |
+
kernel-builder.lib.genFlakeOutputs {
|
| 14 |
+
path = ./.;
|
| 15 |
+
rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
|
| 16 |
+
pythonCheckInputs = pkgs: with pkgs; [ einops ];
|
| 17 |
+
};
|
| 18 |
+
}
|
tests/test_causal_conv1d.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (C) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import math
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
import pytest
|
| 9 |
+
|
| 10 |
+
from einops import rearrange
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
from causal_conv1d import causal_conv1d_fn, causal_conv1d_update, causal_conv1d_varlen_states
|
| 14 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref
|
| 15 |
+
from causal_conv1d.causal_conv1d_interface import causal_conv1d_update_ref
|
| 16 |
+
from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states_ref
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@pytest.mark.parametrize("return_final_states", [False, True])
|
| 20 |
+
# @pytest.mark.parametrize("return_final_states", [True])
|
| 21 |
+
@pytest.mark.parametrize("has_initial_states", [False, True])
|
| 22 |
+
# @pytest.mark.parametrize("has_initial_states", [False])
|
| 23 |
+
@pytest.mark.parametrize("channel_last", [False, True])
|
| 24 |
+
# @pytest.mark.parametrize('channel_last', [True])
|
| 25 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 26 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 27 |
+
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 28 |
+
# @pytest.mark.parametrize('silu_activation', [True])
|
| 29 |
+
@pytest.mark.parametrize("has_bias", [False, True])
|
| 30 |
+
# @pytest.mark.parametrize('has_bias', [True])
|
| 31 |
+
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 32 |
+
# @pytest.mark.parametrize('width', [3])
|
| 33 |
+
@pytest.mark.parametrize(
|
| 34 |
+
"seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 35 |
+
)
|
| 36 |
+
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 37 |
+
# @pytest.mark.parametrize('seqlen', [128])
|
| 38 |
+
@pytest.mark.parametrize('dim', [64, 4096 + 32])
|
| 39 |
+
# @pytest.mark.parametrize('dim', [64])
|
| 40 |
+
def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states):
|
| 41 |
+
if not channel_last and (has_initial_states or return_final_states):
|
| 42 |
+
pytest.skip("Only channel_last support initial_states or return_final_states")
|
| 43 |
+
device = "cuda"
|
| 44 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 45 |
+
if itype == torch.bfloat16:
|
| 46 |
+
rtol, atol = 1e-2, 5e-2
|
| 47 |
+
rtolw, atolw = (1e-3, 1e-3)
|
| 48 |
+
# set seed
|
| 49 |
+
torch.random.manual_seed(0)
|
| 50 |
+
batch = 2
|
| 51 |
+
# batch = 1
|
| 52 |
+
if not channel_last:
|
| 53 |
+
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
|
| 54 |
+
else:
|
| 55 |
+
x = rearrange(
|
| 56 |
+
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 57 |
+
).requires_grad_()
|
| 58 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 59 |
+
if has_bias:
|
| 60 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 61 |
+
else:
|
| 62 |
+
bias = None
|
| 63 |
+
if has_initial_states:
|
| 64 |
+
initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()
|
| 65 |
+
else:
|
| 66 |
+
initial_states = None
|
| 67 |
+
x_ref = x.detach().clone().requires_grad_()
|
| 68 |
+
weight_ref = weight.detach().clone().requires_grad_()
|
| 69 |
+
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
|
| 70 |
+
initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None
|
| 71 |
+
activation = None if not silu_activation else "silu"
|
| 72 |
+
out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states,
|
| 73 |
+
activation=activation)
|
| 74 |
+
out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation)
|
| 75 |
+
if return_final_states:
|
| 76 |
+
out, final_states = out
|
| 77 |
+
out_ref, final_states_ref = out_ref
|
| 78 |
+
print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}")
|
| 79 |
+
print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}")
|
| 80 |
+
assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
|
| 81 |
+
|
| 82 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 83 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 84 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 85 |
+
|
| 86 |
+
if return_final_states:
|
| 87 |
+
out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
|
| 88 |
+
out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
|
| 89 |
+
|
| 90 |
+
g = torch.randn_like(out)
|
| 91 |
+
out.backward(g)
|
| 92 |
+
out_ref.backward(g)
|
| 93 |
+
|
| 94 |
+
print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
|
| 95 |
+
print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
|
| 96 |
+
if has_bias:
|
| 97 |
+
print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
|
| 98 |
+
if has_initial_states:
|
| 99 |
+
print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}")
|
| 100 |
+
|
| 101 |
+
assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 102 |
+
assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
|
| 103 |
+
if has_bias:
|
| 104 |
+
assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
|
| 105 |
+
if has_initial_states:
|
| 106 |
+
assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 110 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 111 |
+
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 112 |
+
# @pytest.mark.parametrize('silu_activation', [True])
|
| 113 |
+
@pytest.mark.parametrize("has_bias", [False, True])
|
| 114 |
+
# @pytest.mark.parametrize('has_bias', [True])
|
| 115 |
+
@pytest.mark.parametrize("has_cache_seqlens", [False, True])
|
| 116 |
+
# @pytest.mark.parametrize('has_cache_seqlens', [True])
|
| 117 |
+
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
| 118 |
+
# @pytest.mark.parametrize('seqlen', [4])
|
| 119 |
+
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 120 |
+
# @pytest.mark.parametrize('width', [4])
|
| 121 |
+
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
| 122 |
+
# @pytest.mark.parametrize("dim", [2048])
|
| 123 |
+
def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
|
| 124 |
+
device = "cuda"
|
| 125 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 126 |
+
if itype == torch.bfloat16:
|
| 127 |
+
rtol, atol = 1e-2, 5e-2
|
| 128 |
+
rtolw, atolw = (1e-3, 1e-3)
|
| 129 |
+
# set seed
|
| 130 |
+
torch.random.manual_seed(0)
|
| 131 |
+
batch = 64
|
| 132 |
+
# batch = 1
|
| 133 |
+
# dim = 64
|
| 134 |
+
x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 135 |
+
state_len = torch.randint(width - 1, width + 10, (1,)).item()
|
| 136 |
+
conv_state = torch.randn(batch, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 137 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 138 |
+
if has_bias:
|
| 139 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 140 |
+
else:
|
| 141 |
+
bias = None
|
| 142 |
+
conv_state_ref = conv_state.detach().clone()
|
| 143 |
+
activation = None if not silu_activation else "silu"
|
| 144 |
+
cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
|
| 145 |
+
if has_cache_seqlens else None)
|
| 146 |
+
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
|
| 147 |
+
out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
|
| 148 |
+
|
| 149 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 150 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 151 |
+
assert torch.equal(conv_state, conv_state_ref)
|
| 152 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 153 |
+
|
| 154 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 155 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 156 |
+
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 157 |
+
# @pytest.mark.parametrize('silu_activation', [True])
|
| 158 |
+
@pytest.mark.parametrize("has_bias", [False, True])
|
| 159 |
+
# @pytest.mark.parametrize('has_bias', [True])
|
| 160 |
+
@pytest.mark.parametrize("has_cache_seqlens", [False, True])
|
| 161 |
+
# @pytest.mark.parametrize('has_cache_seqlens', [True])
|
| 162 |
+
@pytest.mark.parametrize("seqlen", [1, 4, 5])
|
| 163 |
+
# @pytest.mark.parametrize('seqlen', [4])
|
| 164 |
+
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 165 |
+
# @pytest.mark.parametrize('width', [4])
|
| 166 |
+
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
| 167 |
+
# @pytest.mark.parametrize("dim", [2048])
|
| 168 |
+
def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
|
| 169 |
+
device = "cuda"
|
| 170 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 171 |
+
if itype == torch.bfloat16:
|
| 172 |
+
rtol, atol = 1e-2, 5e-2
|
| 173 |
+
rtolw, atolw = (1e-3, 1e-3)
|
| 174 |
+
# set seed
|
| 175 |
+
torch.random.manual_seed(0)
|
| 176 |
+
batch = 64
|
| 177 |
+
# batch = 1
|
| 178 |
+
# dim = 64
|
| 179 |
+
x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 180 |
+
state_len = torch.randint(width - 1, width + 10, (1,)).item()
|
| 181 |
+
|
| 182 |
+
total_entries = 10 * batch
|
| 183 |
+
conv_state = torch.randn(total_entries, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
|
| 184 |
+
conv_state_indices = torch.randperm(total_entries)[:batch].to(dtype=torch.int32, device=device)
|
| 185 |
+
|
| 186 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 187 |
+
if has_bias:
|
| 188 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 189 |
+
else:
|
| 190 |
+
bias = None
|
| 191 |
+
conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
|
| 192 |
+
activation = None if not silu_activation else "silu"
|
| 193 |
+
cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
|
| 194 |
+
if has_cache_seqlens else None)
|
| 195 |
+
out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation,
|
| 196 |
+
cache_seqlens=cache_seqlens, conv_state_indices=conv_state_indices)
|
| 197 |
+
out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
|
| 198 |
+
|
| 199 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 200 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 201 |
+
assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
|
| 202 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 206 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 207 |
+
@pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
|
| 208 |
+
# @pytest.mark.parametrize("dim", [2048])
|
| 209 |
+
def test_causal_conv1d_get_states(dim, itype):
|
| 210 |
+
device = "cuda"
|
| 211 |
+
# set seed
|
| 212 |
+
torch.random.manual_seed(0)
|
| 213 |
+
seqlens = torch.randint(1, 32, (100,), device=device)
|
| 214 |
+
total_seqlen = seqlens.sum().item()
|
| 215 |
+
x = torch.randn(total_seqlen, dim, device=device, dtype=itype)
|
| 216 |
+
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
|
| 217 |
+
state_len = 20
|
| 218 |
+
out = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
|
| 219 |
+
out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
|
| 220 |
+
assert torch.equal(out, out_ref)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
# @pytest.mark.parametrize("channel_last", [False, True])
|
| 224 |
+
@pytest.mark.parametrize('channel_last', [True])
|
| 225 |
+
# @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 226 |
+
@pytest.mark.parametrize('itype', [torch.bfloat16])
|
| 227 |
+
# @pytest.mark.parametrize("silu_activation", [False, True])
|
| 228 |
+
@pytest.mark.parametrize('silu_activation', [True])
|
| 229 |
+
# @pytest.mark.parametrize("has_bias", [False, True])
|
| 230 |
+
@pytest.mark.parametrize('has_bias', [True])
|
| 231 |
+
# @pytest.mark.parametrize("width", [2, 3, 4])
|
| 232 |
+
@pytest.mark.parametrize('width', [4])
|
| 233 |
+
@pytest.mark.parametrize(
|
| 234 |
+
# "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 235 |
+
"seqlen", [2048]
|
| 236 |
+
)
|
| 237 |
+
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 238 |
+
# @pytest.mark.parametrize('seqlen', [128])
|
| 239 |
+
def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
|
| 240 |
+
device = "cuda"
|
| 241 |
+
# set seed
|
| 242 |
+
torch.random.manual_seed(0)
|
| 243 |
+
batch = 2
|
| 244 |
+
# batch = 1
|
| 245 |
+
dim = 4096 + 32 # Try dim not divisible by 64
|
| 246 |
+
# dim = 64
|
| 247 |
+
if not channel_last:
|
| 248 |
+
x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
|
| 249 |
+
else:
|
| 250 |
+
x = rearrange(
|
| 251 |
+
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 252 |
+
).requires_grad_()
|
| 253 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 254 |
+
if has_bias:
|
| 255 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 256 |
+
else:
|
| 257 |
+
bias = None
|
| 258 |
+
activation = None if not silu_activation else "silu"
|
| 259 |
+
out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
|
| 260 |
+
g = torch.randn_like(out0)
|
| 261 |
+
dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
|
| 262 |
+
dw_atol = 1e-4
|
| 263 |
+
db_atol = 1e-4
|
| 264 |
+
|
| 265 |
+
for i in range(10000):
|
| 266 |
+
out = causal_conv1d_fn(x, weight, bias, activation=activation)
|
| 267 |
+
dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
|
| 268 |
+
dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
|
| 269 |
+
# if not dw_equal:
|
| 270 |
+
# breakpoint()
|
| 271 |
+
if has_bias:
|
| 272 |
+
db_equal = torch.allclose(db, db0, atol=db_atol)
|
| 273 |
+
# if not db_equal:
|
| 274 |
+
# breakpoint()
|
| 275 |
+
assert torch.equal(out, out0)
|
| 276 |
+
assert torch.equal(dx, dx0)
|
| 277 |
+
assert dw_equal
|
| 278 |
+
if has_bias:
|
| 279 |
+
assert dw_equal
|
| 280 |
+
|
| 281 |
+
|
| 282 |
+
@pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
|
| 283 |
+
# @pytest.mark.parametrize('itype', [torch.float16])
|
| 284 |
+
@pytest.mark.parametrize("silu_activation", [False, True])
|
| 285 |
+
# @pytest.mark.parametrize('silu_activation', [False])
|
| 286 |
+
@pytest.mark.parametrize("has_bias", [False, True])
|
| 287 |
+
# @pytest.mark.parametrize('has_bias', [False])
|
| 288 |
+
@pytest.mark.parametrize("width", [2, 3, 4])
|
| 289 |
+
# @pytest.mark.parametrize('width', [2])
|
| 290 |
+
@pytest.mark.parametrize(
|
| 291 |
+
"seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
|
| 292 |
+
)
|
| 293 |
+
# @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
|
| 294 |
+
# @pytest.mark.parametrize('seqlen', [2048])
|
| 295 |
+
@pytest.mark.parametrize('dim', [64, 4096 + 32])
|
| 296 |
+
# @pytest.mark.parametrize('dim', [64])
|
| 297 |
+
def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype):
|
| 298 |
+
device = "cuda"
|
| 299 |
+
rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
|
| 300 |
+
if itype == torch.bfloat16:
|
| 301 |
+
rtol, atol = 1e-2, 5e-2
|
| 302 |
+
rtolw, atolw = (1e-3, 1e-3)
|
| 303 |
+
# set seed
|
| 304 |
+
torch.random.manual_seed(seqlen + dim + width)
|
| 305 |
+
batch = 3
|
| 306 |
+
seqlens = []
|
| 307 |
+
for b in range(batch):
|
| 308 |
+
nsplits = torch.randint(1, 5, (1,)).item()
|
| 309 |
+
eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
|
| 310 |
+
seqlens.append(torch.diff(torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist())
|
| 311 |
+
assert sum(seqlens[-1]) == seqlen
|
| 312 |
+
assert all(s > 0 for s in seqlens[-1])
|
| 313 |
+
# Only support channel_last
|
| 314 |
+
x = rearrange(
|
| 315 |
+
torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
|
| 316 |
+
).requires_grad_()
|
| 317 |
+
weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
|
| 318 |
+
if has_bias:
|
| 319 |
+
bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
|
| 320 |
+
else:
|
| 321 |
+
bias = None
|
| 322 |
+
seq_idx = torch.stack([torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(sl)], dim=0)
|
| 323 |
+
for sl in seqlens], dim=0)
|
| 324 |
+
x_ref = x.detach().clone().requires_grad_()
|
| 325 |
+
weight_ref = weight.detach().clone().requires_grad_()
|
| 326 |
+
bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
|
| 327 |
+
activation = None if not silu_activation else "silu"
|
| 328 |
+
out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation)
|
| 329 |
+
out_ref = []
|
| 330 |
+
for b in range(batch):
|
| 331 |
+
out_ref_b = []
|
| 332 |
+
for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2):
|
| 333 |
+
out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation))
|
| 334 |
+
out_ref.append(torch.cat(out_ref_b, dim=2))
|
| 335 |
+
out_ref = torch.cat(out_ref, dim=0)
|
| 336 |
+
|
| 337 |
+
print(f"Output max diff: {(out - out_ref).abs().max().item()}")
|
| 338 |
+
print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
|
| 339 |
+
assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
|
| 340 |
+
|
| 341 |
+
g = torch.randn_like(out)
|
| 342 |
+
out_ref.backward(g)
|
| 343 |
+
out.backward(g)
|
| 344 |
+
|
| 345 |
+
print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
|
| 346 |
+
print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
|
| 347 |
+
if has_bias:
|
| 348 |
+
print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
|
| 349 |
+
|
| 350 |
+
assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
|
| 351 |
+
assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
|
| 352 |
+
if has_bias:
|
| 353 |
+
assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
|
torch-ext/causal_conv1d/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
|
| 2 |
+
from .causal_conv1d_varlen import causal_conv1d_varlen_states
|
| 3 |
+
|
| 4 |
+
__all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
|
torch-ext/causal_conv1d/causal_conv1d_interface.py
ADDED
|
@@ -0,0 +1,242 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
import torch.nn.functional as F
|
| 5 |
+
|
| 6 |
+
from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class CausalConv1dFn(torch.autograd.Function):
|
| 10 |
+
@staticmethod
|
| 11 |
+
def forward(
|
| 12 |
+
ctx,
|
| 13 |
+
x,
|
| 14 |
+
weight,
|
| 15 |
+
bias=None,
|
| 16 |
+
seq_idx=None,
|
| 17 |
+
initial_states=None,
|
| 18 |
+
return_final_states=False,
|
| 19 |
+
final_states_out=None,
|
| 20 |
+
activation=None,
|
| 21 |
+
):
|
| 22 |
+
if activation not in [None, "silu", "swish"]:
|
| 23 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 24 |
+
if x.stride(2) != 1 and x.stride(1) != 1:
|
| 25 |
+
x = x.contiguous()
|
| 26 |
+
bias = bias.contiguous() if bias is not None else None
|
| 27 |
+
if seq_idx is not None:
|
| 28 |
+
assert (
|
| 29 |
+
initial_states is None
|
| 30 |
+
), "initial_states must be None if seq_idx is not None"
|
| 31 |
+
assert (
|
| 32 |
+
not return_final_states
|
| 33 |
+
), "If seq_idx is not None, we don't return final_states_out"
|
| 34 |
+
seq_idx = seq_idx.contiguous() if seq_idx is not None else None
|
| 35 |
+
if initial_states is not None and (
|
| 36 |
+
initial_states.stride(2) != 1 and initial_states.stride(1) != 1
|
| 37 |
+
):
|
| 38 |
+
initial_states = initial_states.contiguous()
|
| 39 |
+
if return_final_states:
|
| 40 |
+
assert (
|
| 41 |
+
x.stride(1) == 1
|
| 42 |
+
), "Only channel-last layout support returning final_states_out"
|
| 43 |
+
if final_states_out is not None:
|
| 44 |
+
assert (
|
| 45 |
+
final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
|
| 46 |
+
)
|
| 47 |
+
else:
|
| 48 |
+
batch, dim, seqlen = x.shape
|
| 49 |
+
width = weight.shape[1]
|
| 50 |
+
final_states_out = torch.empty(
|
| 51 |
+
batch, width - 1, dim, device=x.device, dtype=x.dtype
|
| 52 |
+
).transpose(1, 2)
|
| 53 |
+
else:
|
| 54 |
+
final_states_out = None
|
| 55 |
+
ctx.activation = activation in ["silu", "swish"]
|
| 56 |
+
out = causal_conv1d_fwd_function(
|
| 57 |
+
x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
|
| 58 |
+
)
|
| 59 |
+
ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
|
| 60 |
+
ctx.return_final_states = return_final_states
|
| 61 |
+
ctx.return_dinitial_states = (
|
| 62 |
+
initial_states is not None and initial_states.requires_grad
|
| 63 |
+
)
|
| 64 |
+
return out if not return_final_states else (out, final_states_out)
|
| 65 |
+
|
| 66 |
+
@staticmethod
|
| 67 |
+
def backward(ctx, dout, *args):
|
| 68 |
+
x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
|
| 69 |
+
dfinal_states = args[0] if ctx.return_final_states else None
|
| 70 |
+
if dout.stride(2) != 1 and dout.stride(1) != 1:
|
| 71 |
+
dout = dout.contiguous()
|
| 72 |
+
# The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
|
| 73 |
+
# backward of conv1d with the backward of chunk).
|
| 74 |
+
# Here we just pass in None and dx will be allocated in the C++ code.
|
| 75 |
+
dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
|
| 76 |
+
x,
|
| 77 |
+
weight,
|
| 78 |
+
bias,
|
| 79 |
+
dout,
|
| 80 |
+
seq_idx,
|
| 81 |
+
initial_states,
|
| 82 |
+
dfinal_states,
|
| 83 |
+
None,
|
| 84 |
+
ctx.return_dinitial_states,
|
| 85 |
+
ctx.activation,
|
| 86 |
+
)
|
| 87 |
+
return (
|
| 88 |
+
dx,
|
| 89 |
+
dweight,
|
| 90 |
+
dbias if bias is not None else None,
|
| 91 |
+
None,
|
| 92 |
+
dinitial_states if initial_states is not None else None,
|
| 93 |
+
None,
|
| 94 |
+
None,
|
| 95 |
+
None,
|
| 96 |
+
)
|
| 97 |
+
|
| 98 |
+
|
| 99 |
+
def causal_conv1d_fn(
|
| 100 |
+
x,
|
| 101 |
+
weight,
|
| 102 |
+
bias=None,
|
| 103 |
+
seq_idx=None,
|
| 104 |
+
initial_states=None,
|
| 105 |
+
return_final_states=False,
|
| 106 |
+
final_states_out=None,
|
| 107 |
+
activation=None,
|
| 108 |
+
):
|
| 109 |
+
"""
|
| 110 |
+
x: (batch, dim, seqlen)
|
| 111 |
+
weight: (dim, width)
|
| 112 |
+
bias: (dim,)
|
| 113 |
+
seq_idx: (batch, seqlen)
|
| 114 |
+
initial_states: (batch, dim, width - 1)
|
| 115 |
+
final_states_out: (batch, dim, width - 1), to be written to
|
| 116 |
+
activation: either None or "silu" or "swish"
|
| 117 |
+
|
| 118 |
+
out: (batch, dim, seqlen)
|
| 119 |
+
"""
|
| 120 |
+
return CausalConv1dFn.apply(
|
| 121 |
+
x,
|
| 122 |
+
weight,
|
| 123 |
+
bias,
|
| 124 |
+
seq_idx,
|
| 125 |
+
initial_states,
|
| 126 |
+
return_final_states,
|
| 127 |
+
final_states_out,
|
| 128 |
+
activation,
|
| 129 |
+
)
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def causal_conv1d_ref(
|
| 133 |
+
x,
|
| 134 |
+
weight,
|
| 135 |
+
bias=None,
|
| 136 |
+
initial_states=None,
|
| 137 |
+
return_final_states=False,
|
| 138 |
+
final_states_out=None,
|
| 139 |
+
activation=None,
|
| 140 |
+
):
|
| 141 |
+
"""
|
| 142 |
+
x: (batch, dim, seqlen)
|
| 143 |
+
weight: (dim, width)
|
| 144 |
+
bias: (dim,)
|
| 145 |
+
initial_states: (batch, dim, width - 1)
|
| 146 |
+
final_states_out: (batch, dim, width - 1)
|
| 147 |
+
|
| 148 |
+
out: (batch, dim, seqlen)
|
| 149 |
+
"""
|
| 150 |
+
if activation not in [None, "silu", "swish"]:
|
| 151 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 152 |
+
dtype_in = x.dtype
|
| 153 |
+
x = x.to(weight.dtype)
|
| 154 |
+
seqlen = x.shape[-1]
|
| 155 |
+
dim, width = weight.shape
|
| 156 |
+
if initial_states is None:
|
| 157 |
+
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
|
| 158 |
+
else:
|
| 159 |
+
x = torch.cat([initial_states, x], dim=-1)
|
| 160 |
+
out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
|
| 161 |
+
out = out[..., :seqlen]
|
| 162 |
+
if return_final_states:
|
| 163 |
+
final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
|
| 164 |
+
dtype_in
|
| 165 |
+
) # (batch, dim, width - 1)
|
| 166 |
+
if final_states_out is not None:
|
| 167 |
+
final_states_out.copy_(final_states)
|
| 168 |
+
else:
|
| 169 |
+
final_states_out = final_states
|
| 170 |
+
out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
| 171 |
+
return out if not return_final_states else (out, final_states_out)
|
| 172 |
+
|
| 173 |
+
|
| 174 |
+
def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
|
| 175 |
+
"""
|
| 176 |
+
x: (batch, dim) or (batch, dim, seqlen)
|
| 177 |
+
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
| 178 |
+
weight: (dim, width)
|
| 179 |
+
bias: (dim,)
|
| 180 |
+
cache_seqlens: (batch,), dtype int32.
|
| 181 |
+
If not None, the conv_state is treated as a circular buffer.
|
| 182 |
+
The conv_state will be updated by copying x to the conv_state starting at the index
|
| 183 |
+
@cache_seqlens % state_len.
|
| 184 |
+
conv_state_indices: (batch,), dtype int32
|
| 185 |
+
If None, the conv_state is a larger tensor along the batch dim,
|
| 186 |
+
and we are selecting the batch coords specified by conv_state_indices.
|
| 187 |
+
Useful for a continuous batching scenario.
|
| 188 |
+
|
| 189 |
+
out: (batch, dim) or (batch, dim, seqlen)
|
| 190 |
+
"""
|
| 191 |
+
if activation not in [None, "silu", "swish"]:
|
| 192 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 193 |
+
activation = activation in ["silu", "swish"]
|
| 194 |
+
unsqueeze = x.dim() == 2
|
| 195 |
+
if unsqueeze:
|
| 196 |
+
x = x.unsqueeze(-1)
|
| 197 |
+
out = causal_conv1d_update_function(
|
| 198 |
+
x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
|
| 199 |
+
)
|
| 200 |
+
if unsqueeze:
|
| 201 |
+
out = out.squeeze(-1)
|
| 202 |
+
return out
|
| 203 |
+
|
| 204 |
+
|
| 205 |
+
def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
|
| 206 |
+
"""
|
| 207 |
+
x: (batch, dim) or (batch, dim, seqlen)
|
| 208 |
+
conv_state: (batch, dim, state_len), where state_len >= width - 1
|
| 209 |
+
weight: (dim, width)
|
| 210 |
+
bias: (dim,)
|
| 211 |
+
cache_seqlens: (batch,), dtype int32.
|
| 212 |
+
If not None, the conv_state is treated as a circular buffer.
|
| 213 |
+
The conv_state will be updated by copying x to the conv_state starting at the index
|
| 214 |
+
@cache_seqlens % state_len before performing the convolution.
|
| 215 |
+
|
| 216 |
+
out: (batch, dim) or (batch, dim, seqlen)
|
| 217 |
+
"""
|
| 218 |
+
if activation not in [None, "silu", "swish"]:
|
| 219 |
+
raise NotImplementedError("activation must be None, silu, or swish")
|
| 220 |
+
dtype_in = x.dtype
|
| 221 |
+
unsqueeze = x.dim() == 2
|
| 222 |
+
if unsqueeze:
|
| 223 |
+
x = x.unsqueeze(-1)
|
| 224 |
+
batch, dim, seqlen = x.shape
|
| 225 |
+
width = weight.shape[1]
|
| 226 |
+
state_len = conv_state.shape[-1]
|
| 227 |
+
assert conv_state.shape == (batch, dim, state_len)
|
| 228 |
+
assert weight.shape == (dim, width)
|
| 229 |
+
if cache_seqlens is None:
|
| 230 |
+
x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
|
| 231 |
+
conv_state.copy_(x_new[:, :, -state_len:])
|
| 232 |
+
else:
|
| 233 |
+
width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
| 234 |
+
width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
| 235 |
+
x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
|
| 236 |
+
copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
|
| 237 |
+
copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
|
| 238 |
+
conv_state.scatter_(2, copy_idx, x)
|
| 239 |
+
out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
|
| 240 |
+
if unsqueeze:
|
| 241 |
+
out = out.squeeze(-1)
|
| 242 |
+
return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
|
torch-ext/causal_conv1d/causal_conv1d_varlen.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
from torch import Tensor
|
| 3 |
+
|
| 4 |
+
import triton
|
| 5 |
+
import triton.language as tl
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
@triton.jit
|
| 9 |
+
def _causal_conv1d_varlen_states(
|
| 10 |
+
X,
|
| 11 |
+
CU_SEQLENS,
|
| 12 |
+
STATES,
|
| 13 |
+
state_len,
|
| 14 |
+
dim,
|
| 15 |
+
stride_x_seqlen, stride_x_dim,
|
| 16 |
+
stride_states_batch, stride_states_seqlen, stride_states_dim,
|
| 17 |
+
BLOCK_M: tl.constexpr,
|
| 18 |
+
BLOCK_N: tl.constexpr
|
| 19 |
+
):
|
| 20 |
+
batch_idx = tl.program_id(2)
|
| 21 |
+
STATES += batch_idx * stride_states_batch
|
| 22 |
+
end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
|
| 23 |
+
start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
|
| 24 |
+
rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 25 |
+
cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
|
| 26 |
+
x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
|
| 27 |
+
mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
|
| 28 |
+
other=0)
|
| 29 |
+
rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
|
| 30 |
+
tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
|
| 31 |
+
x,
|
| 32 |
+
mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
| 36 |
+
"""
|
| 37 |
+
Forward pass only, does not support backward pass.
|
| 38 |
+
Parameters:
|
| 39 |
+
x: (total_tokens, dim)
|
| 40 |
+
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
| 41 |
+
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
| 42 |
+
If some of those elements belong to a different sequence, the value of the states will be zero.
|
| 43 |
+
Return:
|
| 44 |
+
states: (batch, dim, state_len)
|
| 45 |
+
"""
|
| 46 |
+
_, dim = x.shape
|
| 47 |
+
batch = cu_seqlens.shape[0] - 1
|
| 48 |
+
cu_seqlens = cu_seqlens.contiguous()
|
| 49 |
+
states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
| 50 |
+
BLOCK_M = min(triton.next_power_of_2(state_len), 16)
|
| 51 |
+
BLOCK_N = min(triton.next_power_of_2(dim), 256)
|
| 52 |
+
grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
|
| 53 |
+
with torch.cuda.device(x.device.index):
|
| 54 |
+
_causal_conv1d_varlen_states[grid](
|
| 55 |
+
x,
|
| 56 |
+
cu_seqlens,
|
| 57 |
+
states,
|
| 58 |
+
state_len,
|
| 59 |
+
dim,
|
| 60 |
+
x.stride(0), x.stride(1),
|
| 61 |
+
states.stride(0), states.stride(2), states.stride(1),
|
| 62 |
+
BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
|
| 63 |
+
)
|
| 64 |
+
return states
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
|
| 68 |
+
"""
|
| 69 |
+
Forward pass only, does not support backward pass.
|
| 70 |
+
Parameters:
|
| 71 |
+
x: (total_tokens, dim)
|
| 72 |
+
cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
|
| 73 |
+
state_len: int. For each cu_seqlens, how many elements from x should be copied to the state.
|
| 74 |
+
If some of those elements belong to a different sequence, the value of the states will be zero.
|
| 75 |
+
Return:
|
| 76 |
+
states: (batch, dim, state_len)
|
| 77 |
+
"""
|
| 78 |
+
_, dim = x.shape
|
| 79 |
+
batch = cu_seqlens.shape[0] - 1
|
| 80 |
+
cu_seqlens = cu_seqlens.contiguous()
|
| 81 |
+
states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
|
| 82 |
+
for i in range(batch):
|
| 83 |
+
end_idx = cu_seqlens[i + 1]
|
| 84 |
+
start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
|
| 85 |
+
states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
|
| 86 |
+
return states
|
torch-ext/causal_conv1d/cpp_functions.py
ADDED
|
@@ -0,0 +1,96 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) 2024, Tri Dao.
|
| 2 |
+
|
| 3 |
+
import torch
|
| 4 |
+
|
| 5 |
+
from ._ops import ops
|
| 6 |
+
|
| 7 |
+
def causal_conv1d_fwd_function(
|
| 8 |
+
x: torch.Tensor,
|
| 9 |
+
weight: torch.Tensor,
|
| 10 |
+
bias: torch.Tensor | None,
|
| 11 |
+
seq_idx: torch.Tensor | None,
|
| 12 |
+
initial_states: torch.Tensor | None,
|
| 13 |
+
final_states_out: torch.Tensor | None,
|
| 14 |
+
silu_activation: bool,
|
| 15 |
+
) -> torch.Tensor:
|
| 16 |
+
out = torch.empty_like(x)
|
| 17 |
+
ops.causal_conv1d_fwd(
|
| 18 |
+
x=x,
|
| 19 |
+
weight=weight,
|
| 20 |
+
bias=bias,
|
| 21 |
+
seq_idx=seq_idx,
|
| 22 |
+
initial_states=initial_states,
|
| 23 |
+
out=out,
|
| 24 |
+
final_states_out=final_states_out,
|
| 25 |
+
silu_activation=silu_activation,
|
| 26 |
+
)
|
| 27 |
+
return out
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def causal_conv1d_bwd_function(
|
| 31 |
+
x: torch.Tensor,
|
| 32 |
+
weight: torch.Tensor,
|
| 33 |
+
bias: torch.Tensor | None,
|
| 34 |
+
dout: torch.Tensor,
|
| 35 |
+
seq_idx: torch.Tensor | None,
|
| 36 |
+
initial_states: torch.Tensor | None,
|
| 37 |
+
dfinal_states: torch.Tensor | None,
|
| 38 |
+
dx: torch.Tensor | None,
|
| 39 |
+
return_dinitial_states: torch.Tensor,
|
| 40 |
+
silu_activation: bool,
|
| 41 |
+
) -> tuple[torch.Tensor | None]:
|
| 42 |
+
batch_size, dim = x.size()[:2]
|
| 43 |
+
width = weight.size(-1)
|
| 44 |
+
|
| 45 |
+
if dx is None:
|
| 46 |
+
dx = torch.empty_like(x)
|
| 47 |
+
dweight = torch.zeros_like(weight, dtype=torch.float32)
|
| 48 |
+
dbias = None
|
| 49 |
+
if bias is not None:
|
| 50 |
+
dbias = torch.zeros_like(bias, dtype=torch.float32)
|
| 51 |
+
dinitial_states = None
|
| 52 |
+
if return_dinitial_states:
|
| 53 |
+
dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
|
| 54 |
+
|
| 55 |
+
ops.causal_conv1d_bwd(
|
| 56 |
+
x=x,
|
| 57 |
+
weight=weight,
|
| 58 |
+
bias=bias,
|
| 59 |
+
dout=dout,
|
| 60 |
+
seq_idx=seq_idx,
|
| 61 |
+
initial_states=initial_states,
|
| 62 |
+
dfinal_states=dfinal_states,
|
| 63 |
+
dx=dx,
|
| 64 |
+
dweight=dweight,
|
| 65 |
+
dbias=dbias,
|
| 66 |
+
dinitial_states=dinitial_states,
|
| 67 |
+
silu_activation=silu_activation,
|
| 68 |
+
)
|
| 69 |
+
|
| 70 |
+
dweight = dweight.type_as(weight)
|
| 71 |
+
if dbias is not None:
|
| 72 |
+
dbias = dbias.type_as(bias)
|
| 73 |
+
return dx, dweight, dbias, dinitial_states
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def causal_conv1d_update_function(
|
| 77 |
+
x: torch.Tensor,
|
| 78 |
+
conv_state: torch.Tensor,
|
| 79 |
+
weight: torch.Tensor,
|
| 80 |
+
bias: torch.Tensor | None,
|
| 81 |
+
silu_activation: bool,
|
| 82 |
+
cache_seqlens: torch.Tensor | None,
|
| 83 |
+
conv_state_indices: torch.Tensor | None,
|
| 84 |
+
) -> torch.Tensor:
|
| 85 |
+
out = torch.empty_like(x)
|
| 86 |
+
ops.causal_conv1d_update(
|
| 87 |
+
x=x,
|
| 88 |
+
conv_state=conv_state,
|
| 89 |
+
weight=weight,
|
| 90 |
+
bias=bias,
|
| 91 |
+
out=out,
|
| 92 |
+
silu_activation=silu_activation,
|
| 93 |
+
cache_seqlens=cache_seqlens,
|
| 94 |
+
conv_state_indices=conv_state_indices,
|
| 95 |
+
)
|
| 96 |
+
return out
|
torch-ext/pytorch_shim.h
ADDED
|
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/library.h>
|
| 4 |
+
|
| 5 |
+
/**
|
| 6 |
+
* Unforunately, the type signatures of the flash_attn ops are not compatible
|
| 7 |
+
* with the PyTorch library bindings. To get around that we use
|
| 8 |
+
* `make_pytorch_shim` which creates a lambda that exponses the API using
|
| 9 |
+
* PyTorch compatible types to the types, then converts them to the types
|
| 10 |
+
* expected by the flash_attn ops. This shims allows us to make minimal changes
|
| 11 |
+
* to `flash_api.cpp` making it easier to synchronize with upstream changes.
|
| 12 |
+
*
|
| 13 |
+
* The `pytorch_library_compatible_type` struct is used to map from the
|
| 14 |
+
* flash_attn ops types to a PyTorch library compatible one. The main issues is
|
| 15 |
+
* that the following types are not support by PyTorch libary bindings:
|
| 16 |
+
* - `int`
|
| 17 |
+
* - `float`
|
| 18 |
+
* - `std::optional<T> &`
|
| 19 |
+
* - `std::optional<const at::Tensor> &`
|
| 20 |
+
* So we convert them to (respectively):
|
| 21 |
+
* - `int64_t`
|
| 22 |
+
* - `double`
|
| 23 |
+
* - `const std::optional<T>&`
|
| 24 |
+
* - `const std::optional<at::Tensor>&`
|
| 25 |
+
*/
|
| 26 |
+
|
| 27 |
+
template<typename T>
|
| 28 |
+
struct pytorch_library_compatible_type {
|
| 29 |
+
using type = T;
|
| 30 |
+
static T convert_from_type(T arg) { return arg; }
|
| 31 |
+
};
|
| 32 |
+
|
| 33 |
+
template<typename T>
|
| 34 |
+
using pytorch_library_compatible_type_t = \
|
| 35 |
+
typename pytorch_library_compatible_type<T>::type;
|
| 36 |
+
|
| 37 |
+
template<typename T>
|
| 38 |
+
T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg)
|
| 39 |
+
{ return pytorch_library_compatible_type<T>::convert_from_type(arg); }
|
| 40 |
+
|
| 41 |
+
// Map `std::optional<T> &` -> `const std::optional<T>&`
|
| 42 |
+
// (NOTE: this is bit unsafe but non of the ops in flash_attn mutate
|
| 43 |
+
// the optional container)
|
| 44 |
+
template<typename T>
|
| 45 |
+
struct pytorch_library_compatible_type<std::optional<T> &> {
|
| 46 |
+
using type = const std::optional<T>&;
|
| 47 |
+
static std::optional<T>& convert_from_type(const std::optional<T> &arg) {
|
| 48 |
+
return const_cast<std::optional<T>&>(arg);
|
| 49 |
+
}
|
| 50 |
+
};
|
| 51 |
+
|
| 52 |
+
// Map `std::optional<T>` ->
|
| 53 |
+
// `std::optional<pytorch_library_compatible_type_t<T>>`
|
| 54 |
+
// (NOTE: tested for `std::optional<int>` -> `std::optional<int64_t>`)
|
| 55 |
+
template<typename T>
|
| 56 |
+
struct pytorch_library_compatible_type<std::optional<T>> {
|
| 57 |
+
using type = std::optional<pytorch_library_compatible_type_t<T>>;
|
| 58 |
+
static std::optional<pytorch_library_compatible_type_t<T>> convert_from_type(std::optional<T> arg) {
|
| 59 |
+
return arg;
|
| 60 |
+
}
|
| 61 |
+
};
|
| 62 |
+
|
| 63 |
+
// Map `std::optional<const at::Tensor>&` -> `const std::optional<at::Tensor>&`
|
| 64 |
+
template<>
|
| 65 |
+
struct pytorch_library_compatible_type<std::optional<const at::Tensor> &> {
|
| 66 |
+
using type = const std::optional<at::Tensor>&;
|
| 67 |
+
static std::optional<const at::Tensor>& convert_from_type(
|
| 68 |
+
const std::optional<at::Tensor> &arg) {
|
| 69 |
+
return const_cast<std::optional<const at::Tensor>&>(
|
| 70 |
+
reinterpret_cast<const std::optional<const at::Tensor>&>(arg));
|
| 71 |
+
}
|
| 72 |
+
};
|
| 73 |
+
|
| 74 |
+
// Map `int` -> `int64_t`
|
| 75 |
+
template<> struct pytorch_library_compatible_type<int> {
|
| 76 |
+
using type = int64_t;
|
| 77 |
+
static int convert_from_type(int64_t arg) {
|
| 78 |
+
TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
|
| 79 |
+
"int64_t value is too large to be converted to int");
|
| 80 |
+
TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
|
| 81 |
+
"int64_t value is too small to be converted to int");
|
| 82 |
+
return arg;
|
| 83 |
+
}
|
| 84 |
+
};
|
| 85 |
+
|
| 86 |
+
// Map `float` -> `double`
|
| 87 |
+
template<> struct pytorch_library_compatible_type<float> {
|
| 88 |
+
using type = double;
|
| 89 |
+
static float convert_from_type(double arg) {
|
| 90 |
+
TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
|
| 91 |
+
"double value is too large to be converted to float");
|
| 92 |
+
return arg;
|
| 93 |
+
}
|
| 94 |
+
};
|
| 95 |
+
|
| 96 |
+
//
|
| 97 |
+
// Shim Utils
|
| 98 |
+
//
|
| 99 |
+
|
| 100 |
+
template <typename Ret, typename... Args>
|
| 101 |
+
auto make_pytorch_shim(Ret(*fun)(Args... args)){
|
| 102 |
+
return [fun](pytorch_library_compatible_type_t<Args>... args) {
|
| 103 |
+
return fun(convert_from_pytorch_compatible_type<Args>(args)...);
|
| 104 |
+
};
|
| 105 |
+
}
|
torch-ext/torch_binding.cpp
ADDED
|
@@ -0,0 +1,32 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#include <torch/library.h>
|
| 2 |
+
|
| 3 |
+
#include "registration.h"
|
| 4 |
+
|
| 5 |
+
#include "pytorch_shim.h"
|
| 6 |
+
#include "torch_binding.h"
|
| 7 |
+
|
| 8 |
+
TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
|
| 9 |
+
ops.def(
|
| 10 |
+
"causal_conv1d_fwd("
|
| 11 |
+
" Tensor x, Tensor weight, Tensor? bias, Tensor? seq_idx,"
|
| 12 |
+
" Tensor? initial_states, Tensor! out, Tensor!? final_states_out,"
|
| 13 |
+
" bool silu_activation) -> ()");
|
| 14 |
+
ops.impl("causal_conv1d_fwd", torch::kCUDA, make_pytorch_shim(&causal_conv1d_fwd));
|
| 15 |
+
|
| 16 |
+
ops.def(
|
| 17 |
+
"causal_conv1d_bwd("
|
| 18 |
+
" Tensor x, Tensor weight, Tensor? bias, Tensor! dout,"
|
| 19 |
+
" Tensor? seq_idx, Tensor? initial_states, Tensor? dfinal_states,"
|
| 20 |
+
" Tensor! dx, Tensor! dweight, Tensor!? dbias,"
|
| 21 |
+
" Tensor!? dinitial_states, bool silu_activation) -> ()");
|
| 22 |
+
ops.impl("causal_conv1d_bwd", torch::kCUDA, make_pytorch_shim(&causal_conv1d_bwd));
|
| 23 |
+
|
| 24 |
+
ops.def(
|
| 25 |
+
"causal_conv1d_update("
|
| 26 |
+
" Tensor x, Tensor conv_state, Tensor weight, Tensor? bias,"
|
| 27 |
+
" Tensor! out, bool silu_activation, Tensor? cache_seqlens,"
|
| 28 |
+
" Tensor? conv_state_indices) -> ()");
|
| 29 |
+
ops.impl("causal_conv1d_update", torch::kCUDA, make_pytorch_shim(&causal_conv1d_update));
|
| 30 |
+
}
|
| 31 |
+
|
| 32 |
+
REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
|
torch-ext/torch_binding.h
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#pragma once
|
| 2 |
+
|
| 3 |
+
#include <torch/torch.h>
|
| 4 |
+
|
| 5 |
+
void
|
| 6 |
+
causal_conv1d_fwd(const at::Tensor &x,
|
| 7 |
+
const at::Tensor &weight,
|
| 8 |
+
const c10::optional<at::Tensor> &bias_,
|
| 9 |
+
const c10::optional<at::Tensor> &seq_idx_,
|
| 10 |
+
const c10::optional<at::Tensor> &initial_states_,
|
| 11 |
+
at::Tensor &out,
|
| 12 |
+
c10::optional<at::Tensor> &final_states_out_,
|
| 13 |
+
bool silu_activation);
|
| 14 |
+
|
| 15 |
+
void
|
| 16 |
+
causal_conv1d_bwd(const at::Tensor &x,
|
| 17 |
+
const at::Tensor &weight,
|
| 18 |
+
const c10::optional<at::Tensor> &bias_,
|
| 19 |
+
at::Tensor &dout,
|
| 20 |
+
const c10::optional<at::Tensor> &seq_idx_,
|
| 21 |
+
const c10::optional<at::Tensor> &initial_states_,
|
| 22 |
+
const c10::optional<at::Tensor> &dfinal_states_,
|
| 23 |
+
at::Tensor &dx,
|
| 24 |
+
at::Tensor &dweight,
|
| 25 |
+
c10::optional<at::Tensor> &dbias_,
|
| 26 |
+
c10::optional<at::Tensor> &dinitial_states_,
|
| 27 |
+
bool silu_activation);
|
| 28 |
+
|
| 29 |
+
void
|
| 30 |
+
causal_conv1d_update(const at::Tensor &x,
|
| 31 |
+
const at::Tensor &conv_state,
|
| 32 |
+
const at::Tensor &weight,
|
| 33 |
+
const c10::optional<at::Tensor> &bias_,
|
| 34 |
+
at::Tensor &out,
|
| 35 |
+
bool silu_activation,
|
| 36 |
+
const c10::optional<at::Tensor> &cache_seqlens_,
|
| 37 |
+
const c10::optional<at::Tensor> &conv_state_indices_
|
| 38 |
+
);
|
| 39 |
+
|