Kernels

Commit db03e28 · danieldk (HF Staff) committed · 0 parent(s)

Convert causal-conv1d to a Hub kernel
README.md ADDED
@@ -0,0 +1,10 @@
+ ---
+ license: bsd-3-clause
+ tags:
+ - kernel
+ ---
+
+ ## causal-conv1d
+
+ Causal depthwise conv1d kernel by Tri Dao. Source: https://github.com/Dao-AILab/causal-conv1d/
+
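A rough usage sketch (not part of this commit): once the kernel is published on the Hub, it can typically be loaded at runtime with the `kernels` Python package instead of being compiled locally. The repository id below is a placeholder, and the assumption that the loaded module exposes `causal_conv1d_fwd` with the same argument order as the C++ binding added in this commit is exactly that, an assumption.

import torch
from kernels import get_kernel

# Placeholder repo id; substitute the actual Hub repository for this kernel.
causal_conv1d = get_kernel("kernels-community/causal-conv1d")

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float16)
bias = torch.randn(dim, device="cuda", dtype=torch.float16)
out = torch.empty_like(x)

# Assumed to mirror causal_conv1d_fwd(x, weight, bias_, seq_idx_, initial_states_,
# out, final_states_out_, silu_activation) from causal_conv1d.cpp below.
causal_conv1d.causal_conv1d_fwd(x, weight, bias, None, None, out, None, True)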
build.toml ADDED
@@ -0,0 +1,49 @@
+ [general]
+ name = "causal_conv1d"
+ universal = false
+
+ [torch]
+ src = [
+ "torch-ext/pytorch_shim.h",
+ "torch-ext/torch_binding.cpp",
+ "torch-ext/torch_binding.h"
+ ]
+
+ [kernel.causal_conv1d]
+ backend = "cuda"
+ src = [
+ "causal-conv1d/causal_conv1d_bwd.cu",
+ "causal-conv1d/causal_conv1d_common.h",
+ "causal-conv1d/causal_conv1d.cpp",
+ "causal-conv1d/causal_conv1d_fwd.cu",
+ "causal-conv1d/causal_conv1d.h",
+ "causal-conv1d/causal_conv1d_update.cu",
+ "causal-conv1d/static_switch.h",
+ ]
+ include = [ "causal-conv1d" ]
+ depends = [ "torch" ]
+
+ [kernel.causal_conv1d_rocm]
+ backend = "rocm"
+ rocm-archs = [
+ "gfx906",
+ "gfx908",
+ "gfx90a",
+ "gfx940",
+ "gfx941",
+ "gfx942",
+ "gfx1030",
+ "gfx1100",
+ "gfx1101",
+ ]
+ src = [
+ "causal-conv1d/causal_conv1d_bwd.cu",
+ "causal-conv1d/causal_conv1d_common.h",
+ "causal-conv1d/causal_conv1d.cpp",
+ "causal-conv1d/causal_conv1d_fwd.cu",
+ "causal-conv1d/causal_conv1d.h",
+ "causal-conv1d/causal_conv1d_update.cu",
+ "causal-conv1d/static_switch.h",
+ ]
+ include = [ "causal-conv1d" ]
+ depends = [ "torch" ]
causal-conv1d/causal_conv1d.cpp ADDED
@@ -0,0 +1,486 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <torch/all.h>
6
+ #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
7
+ #include <c10/core/DeviceGuard.h>
8
+ #else
9
+ #include <c10/cuda/CUDAGuard.h>
10
+ #endif
11
+
12
+ #include <c10/cuda/CUDAStream.h>
13
+ #include <vector>
14
+
15
+ #include "causal_conv1d.h"
16
+
17
+ #define CHECK_SHAPE(x, ...) TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
18
+
19
+ #define DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(ITYPE, NAME, ...) \
20
+ if (ITYPE == at::ScalarType::Half) { \
21
+ using input_t = at::Half; \
22
+ __VA_ARGS__(); \
23
+ } else if (ITYPE == at::ScalarType::BFloat16) { \
24
+ using input_t = at::BFloat16; \
25
+ __VA_ARGS__(); \
26
+ } else if (ITYPE == at::ScalarType::Float) { \
27
+ using input_t = float; \
28
+ __VA_ARGS__(); \
29
+ } else { \
30
+ AT_ERROR(#NAME, " not implemented for input type '", toString(ITYPE), "'"); \
31
+ }
32
+
33
+ #define DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(WTYPE, NAME, ...) \
34
+ if (WTYPE == at::ScalarType::Half) { \
35
+ using weight_t = at::Half; \
36
+ __VA_ARGS__(); \
37
+ } else if (WTYPE == at::ScalarType::BFloat16) { \
38
+ using weight_t = at::BFloat16; \
39
+ __VA_ARGS__(); \
40
+ } else if (WTYPE == at::ScalarType::Float) { \
41
+ using weight_t = float; \
42
+ __VA_ARGS__(); \
43
+ } else { \
44
+ AT_ERROR(#NAME, " not implemented for weight type '", toString(WTYPE), "'"); \
45
+ }
46
+
47
+ template<typename input_t, typename weight_t>
48
+ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
49
+ template <typename input_t, typename weight_t>
50
+ void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream);
51
+
52
+ template<typename input_t, typename weight_t>
53
+ void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
54
+ template<typename input_t, typename weight_t>
55
+ void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream);
56
+
57
+ template<typename input_t, typename weight_t>
58
+ void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream);
59
+
60
+ void set_conv_params_fwd(ConvParamsBase &params,
61
+ // sizes
62
+ const size_t batch,
63
+ const size_t dim,
64
+ const size_t seqlen,
65
+ const size_t width,
66
+ // device pointers
67
+ const at::Tensor x,
68
+ const at::Tensor weight,
69
+ const at::Tensor out,
70
+ void* bias_ptr,
71
+ bool silu_activation) {
72
+
73
+ // Reset the parameters
74
+ memset(&params, 0, sizeof(params));
75
+
76
+ params.batch = batch;
77
+ params.dim = dim;
78
+ params.seqlen = seqlen;
79
+ params.width = width;
80
+
81
+ params.silu_activation = silu_activation;
82
+
83
+ // Set the pointers and strides.
84
+ params.x_ptr = x.data_ptr();
85
+ params.weight_ptr = weight.data_ptr();
86
+ params.bias_ptr = bias_ptr;
87
+ params.out_ptr = out.data_ptr();
88
+ // All strides are in elements, not bytes.
89
+ params.x_batch_stride = x.stride(0);
90
+ params.x_c_stride = x.stride(1);
91
+ params.x_l_stride = x.stride(-1);
92
+ params.weight_c_stride = weight.stride(0);
93
+ params.weight_width_stride = weight.stride(1);
94
+ params.out_batch_stride = out.stride(0);
95
+ params.out_c_stride = out.stride(1);
96
+ params.out_l_stride = out.stride(-1);
97
+ }
98
+
99
+
100
+ void set_conv_params_bwd(ConvParamsBwd &params,
101
+ // sizes
102
+ const size_t batch,
103
+ const size_t dim,
104
+ const size_t seqlen,
105
+ const size_t width,
106
+ // device pointers
107
+ const at::Tensor x,
108
+ const at::Tensor weight,
109
+ void* bias_ptr,
110
+ const at::Tensor dout,
111
+ const at::Tensor dx,
112
+ const at::Tensor dweight,
113
+ void* dbias_ptr,
114
+ bool silu_activation) {
115
+ // Pass in "dout" instead of "out": we're not going to use "out" at all.
116
+ set_conv_params_fwd(params, batch, dim, seqlen, width,
117
+ x, weight, dout, bias_ptr, silu_activation);
118
+
119
+ // Set the pointers and strides.
120
+ params.dout_ptr = dout.data_ptr();
121
+ params.dx_ptr = dx.data_ptr();
122
+ params.dweight_ptr = dweight.data_ptr();
123
+ params.dbias_ptr = dbias_ptr;
124
+ // All strides are in elements, not bytes.
125
+ params.dout_batch_stride = dout.stride(0);
126
+ params.dout_c_stride = dout.stride(1);
127
+ params.dout_l_stride = dout.stride(2);
128
+ params.dweight_c_stride = dweight.stride(0);
129
+ params.dweight_width_stride = dweight.stride(1);
130
+ params.dx_batch_stride = dx.stride(0);
131
+ params.dx_c_stride = dx.stride(1);
132
+ params.dx_l_stride = dx.stride(2);
133
+ }
134
+
135
+ void
136
+ causal_conv1d_fwd(const at::Tensor &x,
137
+ const at::Tensor &weight,
138
+ const c10::optional<at::Tensor> &bias_,
139
+ const c10::optional<at::Tensor> &seq_idx_,
140
+ const c10::optional<at::Tensor> &initial_states_,
141
+ at::Tensor &out,
142
+ c10::optional<at::Tensor> &final_states_out_,
143
+ bool silu_activation) {
144
+ auto input_type = x.scalar_type();
145
+ auto weight_type = weight.scalar_type();
146
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
147
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
148
+
149
+ TORCH_CHECK(x.is_cuda());
150
+ TORCH_CHECK(weight.is_cuda());
151
+
152
+ const auto sizes = x.sizes();
153
+ const int batch_size = sizes[0];
154
+ const int dim = sizes[1];
155
+ const int seqlen = sizes[2];
156
+ const int width = weight.size(-1);
157
+
158
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
159
+ CHECK_SHAPE(weight, dim, width);
160
+
161
+ TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
162
+ const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
163
+
164
+ if (is_channel_last) {
165
+ TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
166
+ TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
167
+ }
168
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
169
+
170
+ if (bias_.has_value()) {
171
+ auto bias = bias_.value();
172
+ TORCH_CHECK(bias.scalar_type() == weight_type);
173
+ TORCH_CHECK(bias.is_cuda());
174
+ TORCH_CHECK(bias.stride(-1) == 1);
175
+ CHECK_SHAPE(bias, dim);
176
+ }
177
+
178
+ if (seq_idx_.has_value()) {
179
+ TORCH_CHECK(is_channel_last, "seq_idx is only supported for channel last layout");
180
+ auto seq_idx = seq_idx_.value();
181
+ TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
182
+ TORCH_CHECK(seq_idx.is_cuda());
183
+ TORCH_CHECK(seq_idx.is_contiguous());
184
+ CHECK_SHAPE(seq_idx, batch_size, seqlen);
185
+ }
186
+
187
+ ConvParamsBase params;
188
+ set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
189
+ bias_.has_value() ? bias_.value().data_ptr() : nullptr,
190
+ silu_activation);
191
+
192
+ if (seq_idx_.has_value()) {
193
+ params.seq_idx_ptr = seq_idx_.value().data_ptr();
194
+ } else {
195
+ params.seq_idx_ptr = nullptr;
196
+ }
197
+
198
+ if (initial_states_.has_value()) {
199
+ TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
200
+ auto initial_states = initial_states_.value();
201
+ TORCH_CHECK(initial_states.scalar_type() == input_type);
202
+ TORCH_CHECK(initial_states.is_cuda());
203
+ CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
204
+ TORCH_CHECK(initial_states.stride(1) == 1);
205
+ params.initial_states_ptr = initial_states.data_ptr();
206
+ params.initial_states_batch_stride = initial_states.stride(0);
207
+ params.initial_states_c_stride = initial_states.stride(1);
208
+ params.initial_states_l_stride = initial_states.stride(2);
209
+ } else {
210
+ params.initial_states_ptr = nullptr;
211
+ }
212
+
213
+ if (final_states_out_.has_value()) {
214
+ TORCH_CHECK(is_channel_last, "final_states is only supported for channel last layout");
215
+ auto final_states = final_states_out_.value();
216
+ TORCH_CHECK(final_states.scalar_type() == input_type);
217
+ TORCH_CHECK(final_states.is_cuda());
218
+ CHECK_SHAPE(final_states, batch_size, dim, width - 1);
219
+ TORCH_CHECK(final_states.stride(1) == 1);
220
+ params.final_states_ptr = final_states.data_ptr();
221
+ params.final_states_batch_stride = final_states.stride(0);
222
+ params.final_states_c_stride = final_states.stride(1);
223
+ params.final_states_l_stride = final_states.stride(2);
224
+ } else {
225
+ params.final_states_ptr = nullptr;
226
+ }
227
+
228
+ // Otherwise the kernel will be launched from cuda:0 device
229
+ #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
230
+ c10::DeviceGuard device_guard(x.device());
231
+ #else
232
+ at::cuda::CUDAGuard device_guard{x.device()};
233
+ #endif
234
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
235
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_fwd", [&] {
236
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_fwd", [&] {
237
+ if (!is_channel_last) {
238
+ causal_conv1d_fwd_cuda<input_t, weight_t>(params, stream);
239
+ } else {
240
+ causal_conv1d_channellast_fwd_cuda<input_t, weight_t>(params, stream);
241
+ }
242
+ });
243
+ });
244
+ }
245
+
246
+ void
247
+ causal_conv1d_bwd(const at::Tensor &x,
248
+ const at::Tensor &weight,
249
+ const c10::optional<at::Tensor> &bias_,
250
+ at::Tensor &dout,
251
+ const c10::optional<at::Tensor> &seq_idx_,
252
+ const c10::optional<at::Tensor> &initial_states_,
253
+ const c10::optional<at::Tensor> &dfinal_states_,
254
+ at::Tensor &dx,
255
+ at::Tensor &dweight,
256
+ c10::optional<at::Tensor> &dbias_,
257
+ c10::optional<at::Tensor> &dinitial_states_,
258
+ bool silu_activation) {
259
+ auto input_type = x.scalar_type();
260
+ auto weight_type = weight.scalar_type();
261
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
262
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
263
+
264
+ TORCH_CHECK(x.is_cuda());
265
+ TORCH_CHECK(weight.is_cuda());
266
+ TORCH_CHECK(dout.is_cuda());
267
+ TORCH_CHECK(bias_.has_value() == dbias_.has_value());
268
+
269
+ const auto sizes = x.sizes();
270
+ const int batch_size = sizes[0];
271
+ const int dim = sizes[1];
272
+ const int seqlen = sizes[2];
273
+ const int width = weight.size(-1);
274
+
275
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
276
+
277
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
278
+ CHECK_SHAPE(weight, dim, width);
279
+ CHECK_SHAPE(dout, batch_size, dim, seqlen);
280
+
281
+ TORCH_CHECK(x.stride(2) == 1 || x.stride(1) == 1);
282
+ const bool is_channel_last = x.stride(1) == 1 && x.stride(2) > 1;
283
+ if (!is_channel_last && dout.stride(2) != 1) { dout = dout.contiguous(); }
284
+ if (is_channel_last && dout.stride(1) != 1) { dout = dout.transpose(-1, -2).contiguous().transpose(-1, -2); }
285
+
286
+ if (is_channel_last) {
287
+ TORCH_CHECK(dim % 8 == 0, "causal_conv1d only supports channel dimension divisible by 8 for now");
288
+ TORCH_CHECK(x.stride(2) % 8 == 0 and x.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (x.stride(0) and x.stride(2)) to be multiples of 8");
289
+ TORCH_CHECK(dout.stride(2) % 8 == 0 and dout.stride(0) % 8 == 0, "causal_conv1d with channel last layout requires strides (dout.stride(0) and dout.stride(2)) to be multiples of 8");
290
+ }
291
+
292
+ if (bias_.has_value()) {
293
+ auto bias = bias_.value();
294
+ TORCH_CHECK(bias.scalar_type() == weight_type);
295
+ TORCH_CHECK(bias.is_cuda());
296
+ TORCH_CHECK(bias.stride(-1) == 1);
297
+ CHECK_SHAPE(bias, dim);
298
+ }
299
+
300
+ if (seq_idx_.has_value()) {
301
+ TORCH_CHECK(is_channel_last, "seq_idx only supported for channel last layout");
302
+ auto seq_idx = seq_idx_.value();
303
+ TORCH_CHECK(seq_idx.scalar_type() == torch::kInt32);
304
+ TORCH_CHECK(seq_idx.is_cuda());
305
+ TORCH_CHECK(seq_idx.is_contiguous());
306
+ CHECK_SHAPE(seq_idx, batch_size, seqlen);
307
+ }
308
+
309
+ TORCH_CHECK(dx.scalar_type() == input_type);
310
+ TORCH_CHECK(dx.is_cuda());
311
+ CHECK_SHAPE(dx, batch_size, dim, seqlen);
312
+ if (!is_channel_last) { TORCH_CHECK(dx.stride(2) == 1); }
313
+ if (is_channel_last) { TORCH_CHECK(dx.stride(1) == 1); }
314
+
315
+ // Otherwise the kernel will be launched from cuda:0 device
316
+ #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
317
+ c10::Device device = x.device();
318
+ c10::DeviceGuard device_guard(device);
319
+ #else
320
+ at::cuda::CUDAGuard device_guard{x.device()};
321
+ #endif
322
+ ConvParamsBwd params;
323
+ set_conv_params_bwd(params, batch_size, dim, seqlen, width,
324
+ x, weight, bias_.has_value() ? bias_.value().data_ptr() : nullptr,
325
+ dout, dx, dweight, bias_.has_value() ? dbias_.value().data_ptr() : nullptr,
326
+ silu_activation);
327
+
328
+ if (seq_idx_.has_value()) {
329
+ params.seq_idx_ptr = seq_idx_.value().data_ptr();
330
+ } else {
331
+ params.seq_idx_ptr = nullptr;
332
+ }
333
+
334
+ if (initial_states_.has_value()) {
335
+ TORCH_CHECK(is_channel_last, "initial_states is only supported for channel last layout");
336
+ auto initial_states = initial_states_.value();
337
+ TORCH_CHECK(initial_states.scalar_type() == input_type);
338
+ TORCH_CHECK(initial_states.is_cuda());
339
+ CHECK_SHAPE(initial_states, batch_size, dim, width - 1);
340
+ TORCH_CHECK(initial_states.stride(1) == 1);
341
+ params.initial_states_ptr = initial_states.data_ptr();
342
+ params.initial_states_batch_stride = initial_states.stride(0);
343
+ params.initial_states_c_stride = initial_states.stride(1);
344
+ params.initial_states_l_stride = initial_states.stride(2);
345
+ } else {
346
+ params.initial_states_ptr = nullptr;
347
+ }
348
+
349
+ if (dfinal_states_.has_value()) {
350
+ TORCH_CHECK(is_channel_last, "dfinal_states is only supported for channel last layout");
351
+ auto dfinal_states = dfinal_states_.value();
352
+ TORCH_CHECK(dfinal_states.scalar_type() == input_type);
353
+ TORCH_CHECK(dfinal_states.is_cuda());
354
+ CHECK_SHAPE(dfinal_states, batch_size, dim, width - 1);
355
+ params.dfinal_states_ptr = dfinal_states.data_ptr();
356
+ params.dfinal_states_batch_stride = dfinal_states.stride(0);
357
+ params.dfinal_states_c_stride = dfinal_states.stride(1);
358
+ params.dfinal_states_l_stride = dfinal_states.stride(2);
359
+ } else {
360
+ params.dfinal_states_ptr = nullptr;
361
+ }
362
+
363
+ if (dinitial_states_.has_value()) {
364
+ at::Tensor dinitial_states = dinitial_states_.value();
365
+ TORCH_CHECK(dinitial_states.stride(1) == 1);
366
+ params.dinitial_states_ptr = dinitial_states.data_ptr();
367
+ params.dinitial_states_batch_stride = dinitial_states.stride(0);
368
+ params.dinitial_states_c_stride = dinitial_states.stride(1);
369
+ params.dinitial_states_l_stride = dinitial_states.stride(2);
370
+ } else {
371
+ params.dinitial_states_ptr = nullptr;
372
+ }
373
+
374
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
375
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_bwd", [&] {
376
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_bwd", [&] {
377
+ if (!is_channel_last) {
378
+ causal_conv1d_bwd_cuda<input_t, weight_t>(params, stream);
379
+ } else {
380
+ causal_conv1d_channellast_bwd_cuda<input_t, weight_t>(params, stream);
381
+ }
382
+ });
383
+ });
384
+ }
385
+
386
+ void
387
+ causal_conv1d_update(const at::Tensor &x,
388
+ const at::Tensor &conv_state,
389
+ const at::Tensor &weight,
390
+ const c10::optional<at::Tensor> &bias_,
391
+ at::Tensor &out,
392
+ bool silu_activation,
393
+ const c10::optional<at::Tensor> &cache_seqlens_,
394
+ const c10::optional<at::Tensor> &conv_state_indices_
395
+ ) {
396
+ auto input_type = x.scalar_type();
397
+ auto weight_type = weight.scalar_type();
398
+ TORCH_CHECK(input_type == at::ScalarType::Float || input_type == at::ScalarType::Half || input_type == at::ScalarType::BFloat16);
399
+ TORCH_CHECK(weight_type == at::ScalarType::Float || weight_type == at::ScalarType::Half || weight_type == at::ScalarType::BFloat16);
400
+ TORCH_CHECK(conv_state.scalar_type() == input_type);
401
+
402
+ TORCH_CHECK(x.is_cuda());
403
+ TORCH_CHECK(conv_state.is_cuda());
404
+ TORCH_CHECK(weight.is_cuda());
405
+
406
+ const auto sizes = x.sizes();
407
+ const int batch_size = sizes[0];
408
+ const int dim = sizes[1];
409
+ const int seqlen = sizes[2];
410
+ const int width = weight.size(-1);
411
+ const int conv_state_len = conv_state.size(2);
412
+ TORCH_CHECK(conv_state_len >= width - 1);
413
+
414
+ CHECK_SHAPE(x, batch_size, dim, seqlen);
415
+ CHECK_SHAPE(weight, dim, width);
416
+
417
+ TORCH_CHECK(width >= 2 && width <= 4, "causal_conv1d only supports width between 2 and 4");
418
+
419
+ if (bias_.has_value()) {
420
+ auto bias = bias_.value();
421
+ TORCH_CHECK(bias.scalar_type() == weight_type);
422
+ TORCH_CHECK(bias.is_cuda());
423
+ TORCH_CHECK(bias.stride(-1) == 1);
424
+ CHECK_SHAPE(bias, dim);
425
+ }
426
+
427
+ ConvParamsBase params;
428
+ set_conv_params_fwd(params, batch_size, dim, seqlen, width, x, weight, out,
429
+ bias_.has_value() ? bias_.value().data_ptr() : nullptr,
430
+ silu_activation);
431
+ params.conv_state_ptr = conv_state.data_ptr();
432
+ params.conv_state_len = conv_state_len;
433
+ // All strides are in elements, not bytes.
434
+ params.conv_state_batch_stride = conv_state.stride(0);
435
+ params.conv_state_c_stride = conv_state.stride(1);
436
+ params.conv_state_l_stride = conv_state.stride(2);
437
+
438
+ if (conv_state_indices_.has_value()) {
439
+ auto conv_state_indices = conv_state_indices_.value();
440
+ TORCH_CHECK(conv_state_indices.scalar_type() == torch::kInt32)
441
+ TORCH_CHECK(conv_state_indices.is_cuda());
442
+ TORCH_CHECK(conv_state_indices.stride(0) == 1)
443
+ CHECK_SHAPE(conv_state_indices, batch_size);
444
+
445
+ int conv_state_entries = conv_state.size(0);
446
+ CHECK_SHAPE(conv_state, conv_state_entries, dim, conv_state_len);
447
+
448
+ params.conv_state_indices_ptr = conv_state_indices.data_ptr<int32_t>();
449
+ } else {
450
+ CHECK_SHAPE(conv_state, batch_size, dim, conv_state_len);
451
+ params.conv_state_indices_ptr = nullptr;
452
+ }
453
+
454
+ if (cache_seqlens_.has_value()) {
455
+ auto cache_seqlens = cache_seqlens_.value();
456
+ TORCH_CHECK(cache_seqlens.scalar_type() == torch::kInt32);
457
+ TORCH_CHECK(cache_seqlens.is_cuda());
458
+ TORCH_CHECK(cache_seqlens.stride(-1) == 1);
459
+ CHECK_SHAPE(cache_seqlens, batch_size);
460
+ params.cache_seqlens = cache_seqlens.data_ptr<int32_t>();
461
+ } else {
462
+ params.cache_seqlens = nullptr;
463
+ }
464
+
465
+ // Otherwise the kernel will be launched from cuda:0 device
466
+ #if TORCH_VERSION_MAJOR > 2 || (TORCH_VERSION_MAJOR == 2 && TORCH_VERSION_MINOR >= 6)
467
+ c10::Device device = x.device();
468
+ c10::DeviceGuard device_guard(device);
469
+ #else
470
+ at::cuda::CUDAGuard device_guard{x.device()};
471
+ #endif
472
+ auto stream = at::cuda::getCurrentCUDAStream().stream();
473
+ DISPATCH_ITYPE_FLOAT_AND_HALF_AND_BF16(x.scalar_type(), "causal_conv1d_update", [&] {
474
+ DISPATCH_WTYPE_FLOAT_AND_HALF_AND_BF16(weight.scalar_type(), "causal_conv1d_update", [&] {
475
+ causal_conv1d_update_cuda<input_t, weight_t>(params, stream);
476
+ });
477
+ });
478
+ }
479
+
480
+ /*
481
+ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
482
+ m.def("causal_conv1d_fwd", &causal_conv1d_fwd, "Causal conv1d forward");
483
+ m.def("causal_conv1d_bwd", &causal_conv1d_bwd, "Causal conv1d backward");
484
+ m.def("causal_conv1d_update", &causal_conv1d_update, "Causal conv1d update");
485
+ }
486
+ */
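For reference (an illustration only, not part of this commit): the operation whose shapes and layouts are validated above is an ordinary depthwise convolution restricted to current and past positions, optionally followed by SiLU. A plain PyTorch sketch of those semantics:

import torch
import torch.nn.functional as F

def causal_conv1d_reference(x, weight, bias=None, silu_activation=False):
    # x: (batch, dim, seqlen), weight: (dim, width), bias: (dim,) or None.
    # Left-pad by width - 1 so each output position only sees current and past
    # inputs, convolve each channel with its own filter (groups=dim), then crop
    # back to seqlen; the CUDA kernels above restrict width to 2..4.
    dim, width = weight.shape
    out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
    out = out[..., : x.shape[-1]]
    return F.silu(out) if silu_activation else out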
causal-conv1d/causal_conv1d.h ADDED
@@ -0,0 +1,81 @@
+ /******************************************************************************
+ * Copyright (c) 2024, Tri Dao.
+ ******************************************************************************/
+
+ #pragma once
+
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
+
+ struct ConvParamsBase {
+ using index_t = uint32_t;
+
+ int batch, dim, seqlen, width;
+ bool silu_activation;
+
+ index_t x_batch_stride;
+ index_t x_c_stride;
+ index_t x_l_stride;
+ index_t weight_c_stride;
+ index_t weight_width_stride;
+ index_t out_batch_stride;
+ index_t out_c_stride;
+ index_t out_l_stride;
+
+ int conv_state_len;
+ index_t conv_state_batch_stride;
+ index_t conv_state_c_stride;
+ index_t conv_state_l_stride;
+
+ // Common data pointers.
+ void *__restrict__ x_ptr;
+ void *__restrict__ weight_ptr;
+ void *__restrict__ bias_ptr;
+ void *__restrict__ out_ptr;
+
+ void *__restrict__ conv_state_ptr;
+ int32_t *__restrict__ cache_seqlens;
+
+ // Only used if the elements of the batch are gathered from a larger buffer,
+ // which may happen for continuous batching.
+ int32_t *__restrict__ conv_state_indices_ptr;
+
+ void *__restrict__ seq_idx_ptr;
+
+ // No __restrict__ since initial_states could be the same as final_states.
+ void * initial_states_ptr;
+ index_t initial_states_batch_stride;
+ index_t initial_states_l_stride;
+ index_t initial_states_c_stride;
+
+ void * final_states_ptr;
+ index_t final_states_batch_stride;
+ index_t final_states_l_stride;
+ index_t final_states_c_stride;
+ };
+
+ struct ConvParamsBwd: public ConvParamsBase {
+ index_t dx_batch_stride;
+ index_t dx_c_stride;
+ index_t dx_l_stride;
+ index_t dweight_c_stride;
+ index_t dweight_width_stride;
+ index_t dout_batch_stride;
+ index_t dout_c_stride;
+ index_t dout_l_stride;
+
+ // Common data pointers.
+ void *__restrict__ dx_ptr;
+ void *__restrict__ dweight_ptr;
+ void *__restrict__ dbias_ptr;
+ void *__restrict__ dout_ptr;
+
+ void * dinitial_states_ptr;
+ index_t dinitial_states_batch_stride;
+ index_t dinitial_states_l_stride;
+ index_t dinitial_states_c_stride;
+
+ void * dfinal_states_ptr;
+ index_t dfinal_states_batch_stride;
+ index_t dfinal_states_l_stride;
+ index_t dfinal_states_c_stride;
+ };
causal-conv1d/causal_conv1d_bwd.cu ADDED
@@ -0,0 +1,627 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #ifndef USE_ROCM
10
+ #include <cub/block/block_load.cuh>
11
+ #include <cub/block/block_store.cuh>
12
+ #include <cub/block/block_reduce.cuh>
13
+ #else
14
+ #include <hipcub/hipcub.hpp>
15
+ namespace cub = hipcub;
16
+ #endif
17
+
18
+ #include "causal_conv1d.h"
19
+ #include "causal_conv1d_common.h"
20
+ #include "static_switch.h"
21
+
22
+ template<int kNThreads_, int kWidth_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
23
+ struct Causal_conv1d_bwd_kernel_traits {
24
+ using input_t = input_t_;
25
+ using weight_t = weight_t_;
26
+ static constexpr int kNThreads = kNThreads_;
27
+ static constexpr int kWidth = kWidth_;
28
+ static constexpr bool kSiluAct = kSiluAct_;
29
+ static constexpr int kNBytes = sizeof(input_t);
30
+ static_assert(kNBytes == 2 || kNBytes == 4);
31
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
32
+ static_assert(kWidth <= kNElts);
33
+ // It's possible that we need to do 2 rounds of exchange if input_t is 16 bits
34
+ // (since then we'd have 8 values of float, and each round we can exchange 4 floats).
35
+ static constexpr int kNExchangeRounds = sizeof(float) / sizeof(input_t);
36
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
37
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
38
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
39
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
40
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
41
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
42
+ using BlockReduceFloatT = cub::BlockReduce<float, kNThreads>;
43
+ static constexpr int kSmemIOSize = kIsVecLoad
44
+ ? 0
45
+ : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
46
+ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts * (!kSiluAct ? 1 : kNExchangeRounds + 1);
47
+ static constexpr int kSmemSize = custom_max({kSmemExchangeSize,
48
+ int(sizeof(typename BlockReduceFloatT::TempStorage))}) + (kIsVecLoad ? 0 : kSmemIOSize);
49
+ };
50
+
51
+ template<typename Ktraits>
52
+ __global__ __launch_bounds__(Ktraits::kNThreads)
53
+ void causal_conv1d_bwd_kernel(ConvParamsBwd params) {
54
+ constexpr int kWidth = Ktraits::kWidth;
55
+ constexpr int kNThreads = Ktraits::kNThreads;
56
+ constexpr bool kSiluAct = Ktraits::kSiluAct;
57
+ static constexpr int kNElts = Ktraits::kNElts;
58
+ constexpr int kNExchangeRounds = Ktraits::kNExchangeRounds;
59
+ static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
60
+ using input_t = typename Ktraits::input_t;
61
+ using vec_t = typename Ktraits::vec_t;
62
+ using weight_t = typename Ktraits::weight_t;
63
+
64
+ // Shared memory.
65
+ extern __shared__ char smem_[];
66
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
67
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
68
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
69
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
70
+ vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
71
+ vec_t *smem_exchange_x = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize) + kNThreads * kNExchangeRounds;
72
+ auto& smem_reduce_float = *reinterpret_cast<typename Ktraits::BlockReduceFloatT::TempStorage*>(smem_ + Ktraits::kSmemIOSize);
73
+
74
+ const int tidx = threadIdx.x;
75
+ const int batch_id = blockIdx.x;
76
+ const int dim_id = blockIdx.y;
77
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
78
+ + dim_id * params.x_c_stride;
79
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + dim_id * params.weight_c_stride;
80
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
81
+ + dim_id * params.dout_c_stride;
82
+ input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
83
+ + dim_id * params.dx_c_stride;
84
+ float *dweight = reinterpret_cast<float *>(params.dweight_ptr) + dim_id * params.dweight_c_stride;
85
+ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[dim_id]);
86
+
87
+ // Thread kNThreads - 1 will load the first elements of the next chunk so we initialize those to 0.
88
+ if (tidx == 0) {
89
+ if constexpr (!kSiluAct) {
90
+ input_t zeros[kNElts] = {0};
91
+ smem_exchange[0] = reinterpret_cast<vec_t *>(zeros)[0];
92
+ } else {
93
+ float zeros[kNElts] = {0};
94
+ #pragma unroll
95
+ for (int r = 0; r < kNExchangeRounds; ++r) {
96
+ smem_exchange[r * kNThreads] = reinterpret_cast<vec_t *>(zeros)[r];
97
+ }
98
+ }
99
+ }
100
+
101
+ float weight_vals[kWidth];
102
+ #pragma unroll
103
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = weight[i * params.weight_width_stride]; }
104
+
105
+ float dweight_vals[kWidth] = {0};
106
+ float dbias_val = 0;
107
+
108
+ constexpr int kChunkSize = kNThreads * kNElts;
109
+ const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
110
+ x += (n_chunks - 1) * kChunkSize;
111
+ dout += (n_chunks - 1) * kChunkSize;
112
+ dx += (n_chunks - 1) * kChunkSize;
113
+ for (int chunk = n_chunks - 1; chunk >= 0; --chunk) {
114
+ input_t x_vals_load[2 * kNElts] = {0};
115
+ input_t dout_vals_load[2 * kNElts] = {0};
116
+ if constexpr(kIsVecLoad) {
117
+ typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
118
+ typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(dout), *reinterpret_cast<vec_t (*)[1]>(&dout_vals_load[0]), (params.seqlen - chunk * kChunkSize) / kNElts);
119
+ } else {
120
+ __syncthreads();
121
+ typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
122
+ __syncthreads();
123
+ typename Ktraits::BlockLoadT(smem_load).Load(dout, *reinterpret_cast<input_t (*)[kNElts]>(&dout_vals_load[0]), params.seqlen - chunk * kChunkSize);
124
+ }
125
+ float dout_vals[2 * kNElts], x_vals[2 * kNElts];
126
+ if constexpr (!kSiluAct) {
127
+ __syncthreads();
128
+ // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
129
+ // the first elements of the next chunk.
130
+ if (tidx > 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
131
+ __syncthreads();
132
+ reinterpret_cast<vec_t *>(dout_vals_load)[1] = smem_exchange[tidx < kNThreads - 1 ? tidx + 1 : 0];
133
+ __syncthreads();
134
+ // Now thread 0 can write the first elements of the current chunk.
135
+ if (tidx == 0) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(dout_vals_load)[0]; }
136
+ #pragma unroll
137
+ for (int i = 0; i < 2 * kNElts; ++i) {
138
+ dout_vals[i] = float(dout_vals_load[i]);
139
+ x_vals[i] = float(x_vals_load[i]);
140
+ }
141
+ } else {
142
+ if (tidx == 0 && chunk > 0) {
143
+ if constexpr(kIsVecLoad) {
144
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = reinterpret_cast<vec_t *>(x)[-1];
145
+ } else {
146
+ #pragma unroll
147
+ for (int i = 0; i < kNElts; ++i) {
148
+ if (chunk * kChunkSize + i < params.seqlen) { x_vals_load[i] = x[-kNElts + i]; }
149
+ }
150
+ }
151
+ }
152
+ __syncthreads();
153
+ smem_exchange_x[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1];
154
+ __syncthreads();
155
+ if (tidx > 0) { reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange_x[tidx - 1]; }
156
+ #pragma unroll
157
+ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
158
+ // Recompute the output
159
+ #pragma unroll
160
+ for (int i = 0; i < kNElts; ++i) {
161
+ float out_val = bias_val;
162
+ #pragma unroll
163
+ for (int w = 0; w < kWidth; ++w) {
164
+ out_val += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
165
+ }
166
+ float out_sigmoid_val = 1.0f / (1.0f + expf(-out_val));
167
+ dout_vals[i] = float(dout_vals_load[i]) * out_sigmoid_val
168
+ * (1.0f + out_val * (1.0f - out_sigmoid_val));
169
+ }
170
+ // Exchange the dout_vals. It's possible that we need to do 2 rounds of exchange
171
+ // if input_t is 16 bits (since then we'd have 8 values of float)
172
+ __syncthreads();
173
+ // Thread 0 doesn't write yet, so that thread kNThreads - 1 can read
174
+ // the first elements of the next chunk.
175
+ if (tidx > 0) {
176
+ #pragma unroll
177
+ for (int r = 0; r < kNExchangeRounds; ++r) {
178
+ smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
179
+ }
180
+ }
181
+ __syncthreads();
182
+ #pragma unroll
183
+ for (int r = 0; r < kNExchangeRounds; ++r) {
184
+ reinterpret_cast<vec_t *>(dout_vals)[kNExchangeRounds + r]
185
+ = smem_exchange[r * kNThreads + (tidx < kNThreads - 1 ? tidx + 1 : 0)];
186
+ }
187
+ __syncthreads();
188
+ // Now thread 0 can write the first elements of the current chunk.
189
+ if (tidx == 0) {
190
+ #pragma unroll
191
+ for (int r = 0; r < kNExchangeRounds; ++r) {
192
+ smem_exchange[r * kNThreads + tidx] = reinterpret_cast<vec_t *>(dout_vals)[r];
193
+ }
194
+ }
195
+ }
196
+ dout -= kChunkSize;
197
+ x -= kChunkSize;
198
+
199
+ #pragma unroll
200
+ for (int i = 0; i < kNElts; ++i) { dbias_val += dout_vals[i]; }
201
+
202
+ float dx_vals[kNElts] = {0};
203
+ #pragma unroll
204
+ for (int i = 0; i < kNElts; ++i) {
205
+ #pragma unroll
206
+ for (int w = 0; w < kWidth; ++w) {
207
+ dx_vals[i] += weight_vals[w] * dout_vals[i + kWidth - w - 1];
208
+ }
209
+ }
210
+
211
+ input_t dx_vals_store[kNElts];
212
+ #pragma unroll
213
+ for (int i = 0; i < kNElts; ++i) { dx_vals_store[i] = dx_vals[i]; }
214
+ if constexpr(kIsVecLoad) {
215
+ typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(dx), reinterpret_cast<vec_t (&)[1]>(dx_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
216
+ } else {
217
+ typename Ktraits::BlockStoreT(smem_store).Store(dx, dx_vals_store, params.seqlen - chunk * kChunkSize);
218
+ }
219
+ dx -= kChunkSize;
220
+
221
+ #pragma unroll
222
+ for (int w = 0; w < kWidth; ++w) {
223
+ #pragma unroll
224
+ for (int i = 0; i < kNElts; ++i) {
225
+ dweight_vals[w] += x_vals[kNElts + i] * dout_vals[i + kWidth - w - 1];
226
+ }
227
+ }
228
+ }
229
+
230
+ #pragma unroll
231
+ for (int w = 0; w < kWidth; ++w) {
232
+ __syncthreads();
233
+ dweight_vals[w] = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dweight_vals[w]);
234
+ if (tidx == 0) {
235
+ atomicAdd(&reinterpret_cast<float *>(dweight)[w * params.dweight_width_stride], dweight_vals[w]);
236
+ }
237
+ }
238
+ if (params.bias_ptr != nullptr) {
239
+ __syncthreads();
240
+ dbias_val = typename Ktraits::BlockReduceFloatT(smem_reduce_float).Sum(dbias_val);
241
+ if (tidx == 0) {
242
+ atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[dim_id], dbias_val);
243
+ }
244
+ }
245
+ }
246
+
247
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
248
+ void causal_conv1d_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
249
+ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
250
+ BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
251
+ BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
252
+ using Ktraits = Causal_conv1d_bwd_kernel_traits<kNThreads, kWidth, kSiluAct, kIsVecLoad, input_t, weight_t>;
253
+ constexpr int kSmemSize = Ktraits::kSmemSize;
254
+ dim3 grid(params.batch, params.dim);
255
+ auto kernel = &causal_conv1d_bwd_kernel<Ktraits>;
256
+
257
+ if (kSmemSize >= 48 * 1024) {
258
+ #ifndef USE_ROCM
259
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
260
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
261
+ #else
262
+ // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
263
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
264
+ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
265
+ std::cerr << "Warning (causal_conv1d bwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
266
+ #endif
267
+ }
268
+
269
+
270
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
271
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
272
+ });
273
+ });
274
+ }
275
+
276
+ template<typename input_t, typename weight_t>
277
+ void causal_conv1d_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
278
+ if (params.width == 2) {
279
+ causal_conv1d_bwd_launch<128, 2, input_t, weight_t>(params, stream);
280
+ } else if (params.width == 3) {
281
+ causal_conv1d_bwd_launch<128, 3, input_t, weight_t>(params, stream);
282
+ } else if (params.width == 4) {
283
+ causal_conv1d_bwd_launch<128, 4, input_t, weight_t>(params, stream);
284
+ }
285
+ }
286
+
287
+ template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kSiluAct_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
288
+ struct Causal_conv1d_channellast_bwd_kernel_traits {
289
+ // The cache line is 128 bytes, and we try to read 16 bytes per thread.
290
+ // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
291
+ // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
292
+ // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
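+ // (Worked numbers added for clarity; they just restate the constants below:
+ // 16 bytes per thread is kNElts * kNBytes = 8 * 2 (fp16/bf16) or 4 * 4 (fp32),
+ // so a 128-byte row takes 128 / 16 = 8 threads (kNThreadsPerRow) and covers
+ // 64 or 32 channels; 32 / 8 = 4 columns per warp (kNColsPerWarp), and with
+ // 128 threads = 4 warps that gives 4 * 4 = 16 columns per load (kNColsPerLoad).)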
293
+ using input_t = input_t_;
294
+ using weight_t = weight_t_;
295
+ static constexpr bool kSiluAct = kSiluAct_;
296
+ static constexpr int kNThreads = kNThreads_;
297
+ static_assert(kNThreads % 32 == 0);
298
+ static constexpr int kNWarps = kNThreads / 32;
299
+ static constexpr int kWidth = kWidth_;
300
+ static constexpr int kChunkSizeL = kChunkSizeL_;
301
+ static constexpr int kNBytes = sizeof(input_t);
302
+ static_assert(kNBytes == 2 || kNBytes == 4);
303
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
304
+ static constexpr int kNEltsPerRow = 128 / kNBytes;
305
+ static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
306
+ static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
307
+ static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
308
+ static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
309
+ static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
310
+ static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
311
+ static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
312
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
313
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
314
+ // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
315
+ // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
316
+ // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
317
+ // sizeof(typename BlockStoreT::TempStorage)});
318
+ // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
319
+ };
320
+
321
+ template<typename Ktraits, bool kHasSeqIdx, bool kHasDfinalStates>
322
+ __global__ __launch_bounds__(Ktraits::kNThreads)
323
+ void causal_conv1d_channellast_bwd_kernel(ConvParamsBwd params) {
324
+ constexpr int kWidth = Ktraits::kWidth;
325
+ constexpr int kNThreads = Ktraits::kNThreads;
326
+ constexpr bool kSiluAct = Ktraits::kSiluAct;
327
+ constexpr int kNElts = Ktraits::kNElts;
328
+ constexpr int kNWarp = Ktraits::kNWarps;
329
+ constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
330
+ constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
331
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
332
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
333
+ using input_t = typename Ktraits::input_t;
334
+ using vec_t = typename Ktraits::vec_t;
335
+ using weight_t = typename Ktraits::weight_t;
336
+
337
+ // Shared memory.
338
+ __shared__ input_t dout_smem[kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
339
+ __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL + kWidth - 1][kChunkSizeC + kNElts];
340
+
341
+ const int batch_id = blockIdx.x;
342
+ const int chunk_l_id = blockIdx.y;
343
+ const int chunk_c_id = blockIdx.z;
344
+ const int tid = threadIdx.x;
345
+ const int l_idx = tid / kNThreadsPerC;
346
+ const int c_idx = tid % kNThreadsPerC;
347
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
348
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
349
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
350
+ + chunk_c_id * kChunkSizeC * params.weight_c_stride;
351
+ input_t *dout = reinterpret_cast<input_t *>(params.dout_ptr) + batch_id * params.dout_batch_stride
352
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.dout_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
353
+ input_t *dx = reinterpret_cast<input_t *>(params.dx_ptr) + batch_id * params.dx_batch_stride
354
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.dx_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
355
+ float *dweight = reinterpret_cast<float *>(params.dweight_ptr)
356
+ + chunk_c_id * kChunkSizeC * params.dweight_c_stride;
357
+ int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
358
+ + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
359
+ input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
360
+ : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
361
+ input_t *dinitial_states = params.dinitial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
362
+ : reinterpret_cast<input_t *>(params.dinitial_states_ptr) + batch_id * params.dinitial_states_batch_stride + l_idx * params.dinitial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
363
+ input_t *dfinal_states = params.dfinal_states_ptr == nullptr ? nullptr
364
+ : reinterpret_cast<input_t *>(params.dfinal_states_ptr) + batch_id * params.dfinal_states_batch_stride + chunk_c_id * kChunkSizeC;
365
+
366
+ #pragma unroll
367
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
368
+ input_t dout_vals_load[kNElts] = {0};
369
+ input_t x_vals_load[kNElts] = {0};
370
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
371
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
372
+ reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + l * kLPerLoad * params.dout_l_stride);
373
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
374
+ }
375
+ reinterpret_cast<vec_t *>(dout_smem[l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
376
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
377
+ }
378
+ // Load the elements from the previous chunk or next chunk that are needed for convolution.
379
+ if (l_idx < kWidth - 1) {
380
+ input_t dout_vals_load[kNElts] = {0};
381
+ input_t x_vals_load[kNElts] = {0};
382
+ if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
383
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
384
+ reinterpret_cast<vec_t *>(dout_vals_load)[0] = *reinterpret_cast<vec_t *>(dout + kChunkSizeL * params.dout_l_stride);
385
+ }
386
+ if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
387
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
388
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
389
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
390
+ } else if (initial_states != nullptr
391
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
392
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
393
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
394
+ }
395
+ reinterpret_cast<vec_t *>(dout_smem[kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(dout_vals_load)[0];
396
+ reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
397
+ }
398
+ // Need to load (kWidth - 1) extra x's on the right to recompute the (kChunkSizeL + kWidth - 1) outputs
399
+ if constexpr (kSiluAct) {
400
+ if (l_idx < kWidth - 1) {
401
+ input_t x_vals_load[kNElts] = {0};
402
+ if ((chunk_l_id + 1) * kChunkSizeL + l_idx < params.seqlen
403
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
404
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + kChunkSizeL * params.x_l_stride);
405
+ }
406
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + kChunkSizeL + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
407
+ }
408
+ }
409
+
410
+ __syncthreads();
411
+
412
+ constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
413
+ static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
414
+ constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
415
+ static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
416
+ // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
417
+ static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
418
+ static_assert((kLPerThread & (kLPerThread - 1)) == 0);
419
+ static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
420
+ static_assert(kNThreadsPerRow <= 32);
421
+
422
+ const int row_idx = tid / kNThreadsPerRow;
423
+ const int col_idx = tid % kNThreadsPerRow;
424
+
425
+ float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
426
+ float weight_vals[kWidth] = {0};
427
+ if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
428
+ #pragma unroll
429
+ for (int w = 0; w < kWidth; ++w) {
430
+ weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
431
+ }
432
+ }
433
+ float dout_vals[kLPerThread + kWidth - 1];
434
+ float x_vals[kWidth - 1 + kLPerThread + kWidth - 1];
435
+ #pragma unroll
436
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
437
+ dout_vals[i] = float(dout_smem[col_idx * kLPerThread + i][row_idx]);
438
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
439
+ }
440
+
441
+ int seq_idx_thread[kWidth - 1 + kLPerThread + kWidth - 1];
442
+ if constexpr (kHasSeqIdx) {
443
+ #pragma unroll
444
+ for (int i = 0; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
445
+ const int l_idx = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1);
446
+ seq_idx_thread[i] = l_idx >= 0 && l_idx < params.seqlen ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
447
+ }
448
+ }
449
+
450
+ if constexpr (kSiluAct) { // Recompute the output
451
+ #pragma unroll
452
+ for (int i = kWidth - 1 + kLPerThread; i < kWidth - 1 + kLPerThread + kWidth - 1; ++i) {
453
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
454
+ }
455
+ #pragma unroll
456
+ for (int i = 0; i < kLPerThread + kWidth - 1; ++i) {
457
+ float out_val = bias_val;
458
+ const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
459
+ #pragma unroll
460
+ for (int w = 0; w < kWidth; ++w) {
461
+ if constexpr (!kHasSeqIdx) {
462
+ out_val += weight_vals[w] * x_vals[i + w];
463
+ } else {
464
+ out_val += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
465
+ }
466
+ }
467
+ float out_val_sigmoid = 1.f / (1.f + expf(-out_val));
468
+ dout_vals[i] *= out_val_sigmoid * (1 + out_val * (1 - out_val_sigmoid));
469
+ }
470
+ }
471
+
472
+ float dweight_vals[kWidth] = {0};
473
+ SumOp<float> sum_op;
474
+ #pragma unroll
475
+ for (int w = 0; w < kWidth; ++w) {
476
+ #pragma unroll
477
+ for (int i = 0; i < kLPerThread; ++i) {
478
+ if constexpr (!kHasSeqIdx) {
479
+ dweight_vals[w] += x_vals[i + w] * dout_vals[i];
480
+ } else {
481
+ dweight_vals[w] += seq_idx_thread[i + w] == seq_idx_thread[kWidth - 1 + i] ? x_vals[i + w] * dout_vals[i] : 0.f;
482
+ }
483
+ }
484
+ dweight_vals[w] = Allreduce<kNThreadsPerRow>::run(dweight_vals[w], sum_op);
485
+ if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
486
+ atomicAdd(&reinterpret_cast<float *>(dweight)[row_idx * params.dweight_c_stride + w * params.dweight_width_stride], dweight_vals[w]);
487
+ }
488
+ }
489
+
490
+ if (params.bias_ptr != nullptr) {
491
+ float dbias_val = 0.f;
492
+ for (int i = 0; i < kLPerThread; ++i) { dbias_val += dout_vals[i]; }
493
+ dbias_val = Allreduce<kNThreadsPerRow>::run(dbias_val, sum_op);
494
+ if (col_idx == 0 && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
495
+ atomicAdd(&reinterpret_cast<float *>(params.dbias_ptr)[chunk_c_id * kChunkSizeC + row_idx], dbias_val);
496
+ }
497
+ }
498
+
499
+ float dx_vals[kLPerThread] = {0};
500
+ #pragma unroll
501
+ for (int i = 0; i < kLPerThread; ++i) {
502
+ const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
503
+ #pragma unroll
504
+ for (int w = 0; w < kWidth; ++w) {
505
+ if constexpr (!kHasSeqIdx) {
506
+ dx_vals[i] += weight_vals[kWidth - 1 - w] * dout_vals[i + w];
507
+ } else {
508
+ dx_vals[i] += seq_idx_thread[kWidth - 1 + i + w] == seq_idx_cur ? weight_vals[kWidth - 1 - w] * dout_vals[i + w] : 0.f;
509
+ }
510
+ }
511
+ // if (dfinal_states != nullptr) {
512
+ if constexpr (kHasDfinalStates) {
513
+ if (chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i >= params.seqlen - kWidth + 1
514
+ && chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i < params.seqlen
515
+ && chunk_c_id * kChunkSizeC + row_idx < params.dim) {
516
+ dx_vals[i] += float(dfinal_states[((chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i) - (params.seqlen - kWidth + 1)) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
517
+ }
518
+ }
519
+ }
520
+
521
+ float dxinit_vals[kWidth - 1] = {0};
522
+ static_assert(kLPerThread >= kWidth - 1); // So only threads with col_idx == 0 need to handle dinitial_states
523
+ if (dinitial_states != nullptr && col_idx == 0) {
524
+ #pragma unroll
525
+ for (int i = 0; i < kWidth - 1; ++i) {
526
+ #pragma unroll
527
+ for (int w = 0; w < kWidth; ++w) {
528
+ dxinit_vals[i] += i + w - (kWidth - 1) >= 0 ? weight_vals[kWidth - 1 - w] * dout_vals[i + w - (kWidth - 1)] : 0.f;
529
+ }
530
+ // chunk_l_id must be 0 because dinitial_states != nullptr
531
+ // if (dfinal_states != nullptr) {
532
+ if constexpr (kHasDfinalStates) {
533
+ if (i >= params.seqlen) {
534
+ dxinit_vals[i] += float(dfinal_states[(i - params.seqlen) * params.dfinal_states_l_stride + row_idx * params.dfinal_states_c_stride]);
535
+ }
536
+ }
537
+ }
538
+ }
539
+
540
+ __syncthreads();
541
+ #pragma unroll
542
+ for (int i = 0; i < kLPerThread; ++i) { x_smem[kWidth - 1 + col_idx * kLPerThread + i][row_idx] = dx_vals[i]; }
543
+ if (dinitial_states != nullptr && col_idx == 0) {
544
+ #pragma unroll
545
+ for (int i = 0; i < kWidth - 1; ++i) { x_smem[i][row_idx] = dxinit_vals[i]; }
546
+ }
547
+ __syncthreads();
548
+
549
+ #pragma unroll
550
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
551
+ input_t dx_vals_store[kNElts];
552
+ reinterpret_cast<vec_t *>(dx_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx];
553
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
554
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
555
+ *reinterpret_cast<vec_t *>(dx + l * kLPerLoad * params.dx_l_stride) = reinterpret_cast<vec_t *>(dx_vals_store)[0];
556
+ }
557
+ }
558
+ if (dinitial_states != nullptr
559
+ && l_idx < kWidth - 1
560
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
561
+ input_t dxinit_vals_store[kNElts];
562
+ reinterpret_cast<vec_t *>(dxinit_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx];
563
+ *reinterpret_cast<vec_t *>(dinitial_states) = reinterpret_cast<vec_t *>(dxinit_vals_store)[0];
564
+ }
565
+
566
+ }
567
+
568
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
569
+ void causal_conv1d_channellast_bwd_launch(ConvParamsBwd &params, cudaStream_t stream) {
570
+ BOOL_SWITCH(params.silu_activation, kSiluAct, [&] {
571
+ BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
572
+ BOOL_SWITCH(params.dfinal_states_ptr != nullptr, kHasDfinalStates, [&] {
573
+ BOOL_SWITCH(params.seqlen <= 128, kChunkSizeL64, [&] {
574
+ // kChunkSizeL = 128 is slightly faster than 64 when seqlen is larger
575
+ static constexpr int kChunk = kChunkSizeL64 ? 64 : 128;
576
+ using Ktraits = Causal_conv1d_channellast_bwd_kernel_traits<kNThreads, kWidth, kChunk, kSiluAct, true, input_t, weight_t>;
577
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
578
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
579
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
580
+ const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
581
+ const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
582
+ dim3 grid(params.batch, n_chunks_L, n_chunks_C);
583
+ dim3 block(Ktraits::kNThreads);
584
+ auto kernel = &causal_conv1d_channellast_bwd_kernel<Ktraits, kHasSeqIdx, kHasDfinalStates>;
585
+ // if (kSmemSize >= 48 * 1024) {
586
+ // C10_CUDA_CHECK(cudaFuncSetAttribute(
587
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
588
+ // }
589
+ // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
590
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
591
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
592
+ });
593
+ });
594
+ });
595
+ });
596
+ }
597
+
598
+ template<typename input_t, typename weight_t>
599
+ void causal_conv1d_channellast_bwd_cuda(ConvParamsBwd &params, cudaStream_t stream) {
600
+ if (params.width == 2) {
601
+ causal_conv1d_channellast_bwd_launch<128, 2, input_t, weight_t>(params, stream);
602
+ } else if (params.width == 3) {
603
+ causal_conv1d_channellast_bwd_launch<128, 3, input_t, weight_t>(params, stream);
604
+ } else if (params.width == 4) {
605
+ causal_conv1d_channellast_bwd_launch<128, 4, input_t, weight_t>(params, stream);
606
+ }
607
+ }
608
+
609
+ template void causal_conv1d_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
610
+ template void causal_conv1d_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
611
+ template void causal_conv1d_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
612
+ template void causal_conv1d_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
613
+ template void causal_conv1d_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
614
+ template void causal_conv1d_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
615
+ template void causal_conv1d_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
616
+ template void causal_conv1d_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
617
+ template void causal_conv1d_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
618
+
619
+ template void causal_conv1d_channellast_bwd_cuda<float, float>(ConvParamsBwd &params, cudaStream_t stream);
620
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, float>(ConvParamsBwd &params, cudaStream_t stream);
621
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, float>(ConvParamsBwd &params, cudaStream_t stream);
622
+ template void causal_conv1d_channellast_bwd_cuda<float, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
623
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
624
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::Half>(ConvParamsBwd &params, cudaStream_t stream);
625
+ template void causal_conv1d_channellast_bwd_cuda<float, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
626
+ template void causal_conv1d_channellast_bwd_cuda<at::Half, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
627
+ template void causal_conv1d_channellast_bwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBwd &params, cudaStream_t stream);
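The backward pass is explicitly instantiated above for all nine (input_t, weight_t) combinations of float32, float16, and bfloat16. A minimal sketch of exercising one of those paths through the package's Python wrapper (assuming causal_conv1d_fn as imported in tests/test_causal_conv1d.py further below):

import torch
from causal_conv1d import causal_conv1d_fn

# fp16 activations with fp32 weights is one of the instantiated (input_t, weight_t) pairs.
batch, dim, seqlen, width = 2, 64, 512, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16, requires_grad=True)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float32, requires_grad=True)
bias = torch.randn(dim, device="cuda", dtype=torch.float32, requires_grad=True)

out = causal_conv1d_fn(x, weight, bias, activation="silu")  # forward kernel
out.backward(torch.randn_like(out))                         # dispatches to the backward kernels above
print(x.grad.dtype, weight.grad.dtype)                      # gradients follow the input/weight dtypes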
causal-conv1d/causal_conv1d_common.h ADDED
@@ -0,0 +1,98 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #pragma once
6
+
7
+ #ifndef USE_ROCM
8
+ #include <cuda_bf16.h>
9
+
10
+ template<typename T>
11
+ __device__ inline T shuffle_xor(T val, int offset) {
12
+ return __shfl_xor_sync(uint32_t(-1), val, offset);
13
+ }
14
+
15
+ constexpr size_t custom_max(std::initializer_list<size_t> ilist)
16
+ {
17
+ return std::max(ilist);
18
+ }
19
+
20
+ template<typename T>
21
+ constexpr T constexpr_min(T a, T b) {
22
+ return std::min(a, b);
23
+ }
24
+
25
+ #else
26
+ #include <hip/hip_bf16.h>
27
+
28
+ template<typename T>
29
+ __device__ inline T shuffle_xor(T val, int offset) {
30
+ return __shfl_xor(val, offset);
31
+ }
32
+ constexpr size_t custom_max(std::initializer_list<size_t> ilist)
33
+ {
34
+ return *std::max_element(ilist.begin(), ilist.end());
35
+ }
36
+
37
+ template<typename T>
38
+ constexpr T constexpr_min(T a, T b) {
39
+ return a < b ? a : b;
40
+ }
41
+ #endif
42
+ #include <cuda_fp16.h>
43
+
44
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
45
+
46
+ template<int BYTES> struct BytesToType {};
47
+
48
+ template<> struct BytesToType<16> {
49
+ using Type = uint4;
50
+ static_assert(sizeof(Type) == 16);
51
+ };
52
+
53
+ template<> struct BytesToType<8> {
54
+ using Type = uint64_t;
55
+ static_assert(sizeof(Type) == 8);
56
+ };
57
+
58
+ template<> struct BytesToType<4> {
59
+ using Type = uint32_t;
60
+ static_assert(sizeof(Type) == 4);
61
+ };
62
+
63
+ template<> struct BytesToType<2> {
64
+ using Type = uint16_t;
65
+ static_assert(sizeof(Type) == 2);
66
+ };
67
+
68
+ template<> struct BytesToType<1> {
69
+ using Type = uint8_t;
70
+ static_assert(sizeof(Type) == 1);
71
+ };
72
+
73
+ ////////////////////////////////////////////////////////////////////////////////////////////////////
74
+
75
+ template<typename T>
76
+ struct SumOp {
77
+ __device__ inline T operator()(T const & x, T const & y) { return x + y; }
78
+ };
79
+
80
+ template<int THREADS>
81
+ struct Allreduce {
82
+ static_assert(THREADS == 32 || THREADS == 16 || THREADS == 8 || THREADS == 4);
83
+ template<typename T, typename Operator>
84
+ static __device__ inline T run(T x, Operator &op) {
85
+ constexpr int OFFSET = THREADS / 2;
86
+ x = op(x, shuffle_xor(x, OFFSET));
87
+ return Allreduce<OFFSET>::run(x, op);
88
+ }
89
+ };
90
+
91
+ template<>
92
+ struct Allreduce<2> {
93
+ template<typename T, typename Operator>
94
+ static __device__ inline T run(T x, Operator &op) {
95
+ x = op(x, shuffle_xor(x, 1));
96
+ return x;
97
+ }
98
+ };
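The Allreduce helper above implements a warp-level butterfly reduction on top of shuffle_xor. A toy NumPy simulation of the same access pattern, purely illustrative (the function name here is hypothetical, not part of the kernel):

import numpy as np

def butterfly_allreduce_sum(vals):
    # vals: one value per lane; length must be a power of two (like THREADS above).
    vals = np.asarray(vals, dtype=float)
    offset = len(vals) // 2
    while offset >= 1:
        partner = np.arange(len(vals)) ^ offset  # lane id XOR offset, like __shfl_xor_sync
        vals = vals + vals[partner]              # every lane adds its partner's value
        offset //= 2
    return vals                                  # all lanes end up holding the full sum

print(butterfly_allreduce_sum([1, 2, 3, 4, 5, 6, 7, 8]))  # -> [36. 36. ... 36.]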
causal-conv1d/causal_conv1d_fwd.cu ADDED
@@ -0,0 +1,399 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2024, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #ifndef USE_ROCM
10
+ #include <cub/block/block_load.cuh>
11
+ #include <cub/block/block_store.cuh>
12
+ #else
13
+ #include <hipcub/hipcub.hpp>
14
+ namespace cub = hipcub;
15
+ #endif
16
+
17
+ #include "causal_conv1d.h"
18
+ #include "causal_conv1d_common.h"
19
+ #include "static_switch.h"
20
+
21
+ template<int kNThreads_, int kWidth_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
22
+ struct Causal_conv1d_fwd_kernel_traits {
23
+ using input_t = input_t_;
24
+ using weight_t = weight_t_;
25
+ static constexpr int kNThreads = kNThreads_;
26
+ static constexpr int kWidth = kWidth_;
27
+ static constexpr int kNBytes = sizeof(input_t);
28
+ static_assert(kNBytes == 2 || kNBytes == 4);
29
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
30
+ static_assert(kWidth <= kNElts);
31
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
32
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
33
+ using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNElts, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
34
+ using BlockLoadVecT = cub::BlockLoad<vec_t, kNThreads, 1, cub::BLOCK_LOAD_DIRECT>;
35
+ using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNElts, cub::BLOCK_STORE_WARP_TRANSPOSE>;
36
+ using BlockStoreVecT = cub::BlockStore<vec_t, kNThreads, 1, cub::BLOCK_STORE_DIRECT>;
37
+ static constexpr int kSmemIOSize = kIsVecLoad
38
+ ? 0
39
+ : custom_max({sizeof(typename BlockLoadT::TempStorage), sizeof(typename BlockStoreT::TempStorage)});
40
+ static constexpr int kSmemExchangeSize = kNThreads * kNBytes * kNElts;
41
+ static constexpr int kSmemSize = kSmemIOSize + kSmemExchangeSize;
42
+ };
43
+
44
+ template<typename Ktraits>
45
+ __global__ __launch_bounds__(Ktraits::kNThreads)
46
+ void causal_conv1d_fwd_kernel(ConvParamsBase params) {
47
+ constexpr int kWidth = Ktraits::kWidth;
48
+ constexpr int kNThreads = Ktraits::kNThreads;
49
+ constexpr int kNElts = Ktraits::kNElts;
50
+ static constexpr bool kIsVecLoad = Ktraits::kIsVecLoad;
51
+ using input_t = typename Ktraits::input_t;
52
+ using vec_t = typename Ktraits::vec_t;
53
+ using weight_t = typename Ktraits::weight_t;
54
+
55
+ // Shared memory.
56
+ extern __shared__ char smem_[];
57
+ auto& smem_load = reinterpret_cast<typename Ktraits::BlockLoadT::TempStorage&>(smem_);
58
+ auto& smem_load_vec = reinterpret_cast<typename Ktraits::BlockLoadVecT::TempStorage&>(smem_);
59
+ auto& smem_store = reinterpret_cast<typename Ktraits::BlockStoreT::TempStorage&>(smem_);
60
+ auto& smem_store_vec = reinterpret_cast<typename Ktraits::BlockStoreVecT::TempStorage&>(smem_);
61
+ vec_t *smem_exchange = reinterpret_cast<vec_t *>(smem_ + Ktraits::kSmemIOSize);
62
+
63
+ const int tidx = threadIdx.x;
64
+ const int batch_id = blockIdx.x;
65
+ const int channel_id = blockIdx.y;
66
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
67
+ + channel_id * params.x_c_stride;
68
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
69
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
70
+ + channel_id * params.out_c_stride;
71
+ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
72
+
73
+ // Thread 0 will load the last elements of the previous chunk, so we initialize those to 0.
74
+ if (tidx == 0) {
75
+ input_t zeros[kNElts] = {0};
76
+ smem_exchange[kNThreads - 1] = reinterpret_cast<vec_t *>(zeros)[0];
77
+ }
78
+
79
+ float weight_vals[kWidth];
80
+ #pragma unroll
81
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
82
+
83
+ constexpr int kChunkSize = kNThreads * kNElts;
84
+ const int n_chunks = (params.seqlen + kChunkSize - 1) / kChunkSize;
85
+ for (int chunk = 0; chunk < n_chunks; ++chunk) {
86
+ input_t x_vals_load[2 * kNElts] = {0};
87
+ if constexpr(kIsVecLoad) {
88
+ typename Ktraits::BlockLoadVecT(smem_load_vec).Load(reinterpret_cast<vec_t*>(x), *reinterpret_cast<vec_t (*)[1]>(&x_vals_load[kNElts]), (params.seqlen - chunk * kChunkSize) / kNElts);
89
+ } else {
90
+ __syncthreads();
91
+ typename Ktraits::BlockLoadT(smem_load).Load(x, *reinterpret_cast<input_t (*)[kNElts]>(&x_vals_load[kNElts]), params.seqlen - chunk * kChunkSize);
92
+ }
93
+ x += kChunkSize;
94
+ __syncthreads();
95
+ // Thread kNThreads - 1 doesn't write yet, so that thread 0 can read
96
+ // the last elements of the previous chunk.
97
+ if (tidx < kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
98
+ __syncthreads();
99
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = smem_exchange[tidx > 0 ? tidx - 1 : kNThreads - 1];
100
+ __syncthreads();
101
+ // Now thread kNThreads - 1 can write the last elements of the current chunk.
102
+ if (tidx == kNThreads - 1) { smem_exchange[tidx] = reinterpret_cast<vec_t *>(x_vals_load)[1]; }
103
+
104
+ float x_vals[2 * kNElts];
105
+ #pragma unroll
106
+ for (int i = 0; i < 2 * kNElts; ++i) { x_vals[i] = float(x_vals_load[i]); }
107
+
108
+ float out_vals[kNElts];
109
+ #pragma unroll
110
+ for (int i = 0; i < kNElts; ++i) {
111
+ out_vals[i] = bias_val;
112
+ #pragma unroll
113
+ for (int w = 0; w < kWidth; ++w) {
114
+ out_vals[i] += weight_vals[w] * x_vals[kNElts + i - (kWidth - w - 1)];
115
+ }
116
+ }
117
+
118
+ if (params.silu_activation) {
119
+ #pragma unroll
120
+ for (int i = 0; i < kNElts; ++i) {
121
+ out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i]));
122
+ }
123
+ }
124
+
125
+ input_t out_vals_store[kNElts];
126
+ #pragma unroll
127
+ for (int i = 0; i < kNElts; ++i) { out_vals_store[i] = out_vals[i]; }
128
+ if constexpr(kIsVecLoad) {
129
+ typename Ktraits::BlockStoreVecT(smem_store_vec).Store(reinterpret_cast<vec_t*>(out), reinterpret_cast<vec_t (&)[1]>(out_vals_store), (params.seqlen - chunk * kChunkSize) / kNElts);
130
+ } else {
131
+ typename Ktraits::BlockStoreT(smem_store).Store(out, out_vals_store, params.seqlen - chunk * kChunkSize);
132
+ }
133
+ out += kChunkSize;
134
+ }
135
+ }
136
+
137
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
138
+ void causal_conv1d_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
139
+ static constexpr int kNElts = sizeof(input_t) == 4 ? 4 : 8;
140
+ BOOL_SWITCH(params.seqlen % kNElts == 0, kIsVecLoad, [&] {
141
+ using Ktraits = Causal_conv1d_fwd_kernel_traits<kNThreads, kWidth, kIsVecLoad, input_t, weight_t>;
142
+ constexpr int kSmemSize = Ktraits::kSmemSize;
143
+ dim3 grid(params.batch, params.dim);
144
+
145
+ auto kernel = &causal_conv1d_fwd_kernel<Ktraits>;
146
+
147
+ if (kSmemSize >= 48 * 1024) {
148
+ #ifndef USE_ROCM
149
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
150
+ kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
151
+ #else
152
+ // There is a slight signature discrepancy in HIP and CUDA "FuncSetAttribute" function.
153
+ C10_CUDA_CHECK(cudaFuncSetAttribute(
154
+ (void *) kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
155
+ std::cerr << "Warning (causal_conv1d fwd launch): attempting to set maxDynamicSharedMemorySize on an AMD GPU which is currently a non-op (in ROCm versions <= 6.1). This might lead to undefined behavior. \n" << std::endl;
156
+ #endif
157
+ }
158
+ kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
159
+
160
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
161
+ });
162
+ }
163
+
164
+ template<typename input_t, typename weight_t>
165
+ void causal_conv1d_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
166
+ if (params.width == 2) {
167
+ causal_conv1d_fwd_launch<128, 2, input_t, weight_t>(params, stream);
168
+ } else if (params.width == 3) {
169
+ causal_conv1d_fwd_launch<128, 3, input_t, weight_t>(params, stream);
170
+ } else if (params.width == 4) {
171
+ causal_conv1d_fwd_launch<128, 4, input_t, weight_t>(params, stream);
172
+ }
173
+ }
174
+
175
+ template<int kNThreads_, int kWidth_, int kChunkSizeL_, bool kIsVecLoad_, typename input_t_, typename weight_t_>
176
+ struct Causal_conv1d_channellast_fwd_kernel_traits {
177
+ // The cache line is 128 bytes, and we try to read 16 bytes per thread.
178
+ // So we have 8 threads per "row", so 32 or 64 elements in the channel dimension.
179
+ // That leaves 4 columns per warp, and so 16 columns per block (assuming each block has 128
180
+ // threads). Each load is 16 x 32|64 elements in the L x C dimensions.
181
+ using input_t = input_t_;
182
+ using weight_t = weight_t_;
183
+ static constexpr int kNThreads = kNThreads_;
184
+ static_assert(kNThreads % 32 == 0);
185
+ static constexpr int kNWarps = kNThreads / 32;
186
+ static constexpr int kWidth = kWidth_;
187
+ static constexpr int kChunkSizeL = kChunkSizeL_;
188
+ static constexpr int kNBytes = sizeof(input_t);
189
+ static_assert(kNBytes == 2 || kNBytes == 4);
190
+ static constexpr int kNElts = kNBytes == 4 ? 4 : 8;
191
+ static constexpr int kNEltsPerRow = 128 / kNBytes;
192
+ static constexpr int kNThreadsPerRow = kNEltsPerRow / kNElts; // Always 8 for now
193
+ static_assert(kNThreadsPerRow * kNBytes * kNElts == 128);
194
+ static constexpr int kNColsPerWarp = 32 / kNThreadsPerRow; // Always 4 for now
195
+ static_assert(kNColsPerWarp * kNThreadsPerRow == 32);
196
+ static constexpr int kNColsPerLoad = kNColsPerWarp * kNWarps;
197
+ static constexpr int kNLoads = kChunkSizeL / kNColsPerLoad;
198
+ static_assert(kNLoads * kNColsPerLoad == kChunkSizeL);
199
+ static constexpr bool kIsVecLoad = kIsVecLoad_;
200
+ using vec_t = typename BytesToType<kNBytes * kNElts>::Type;
201
+ // using BlockLoadT = cub::BlockLoad<input_t, kNThreads, kNItems, cub::BLOCK_LOAD_WARP_TRANSPOSE>;
202
+ // using BlockStoreT = cub::BlockStore<input_t, kNThreads, kNItems, cub::BLOCK_STORE_WARP_TRANSPOSE>;
203
+ // static constexpr int kSmemSize = std::max({sizeof(typename BlockLoadT::TempStorage),
204
+ // sizeof(typename BlockStoreT::TempStorage)});
205
+ // static constexpr int kSmemSize = kChunkSizeL * kNEltsPerRow * kNBytes;
206
+ };
207
+
208
+ template<typename Ktraits, bool kHasSeqIdx>
209
+ __global__ __launch_bounds__(Ktraits::kNThreads)
210
+ void causal_conv1d_channellast_fwd_kernel(ConvParamsBase params) {
211
+ constexpr int kWidth = Ktraits::kWidth;
212
+ constexpr int kNThreads = Ktraits::kNThreads;
213
+ constexpr int kNElts = Ktraits::kNElts;
214
+ constexpr int kNWarp = Ktraits::kNWarps;
215
+ constexpr int kNThreadsPerC = Ktraits::kNThreadsPerRow;
216
+ constexpr int kLPerLoad = Ktraits::kNColsPerLoad;
217
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
218
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
219
+ using input_t = typename Ktraits::input_t;
220
+ using vec_t = typename Ktraits::vec_t;
221
+ using weight_t = typename Ktraits::weight_t;
222
+
223
+ // Shared memory.
224
+ __shared__ input_t x_smem[kWidth - 1 + kChunkSizeL][kChunkSizeC + kNElts];
225
+
226
+ const int batch_id = blockIdx.x;
227
+ const int chunk_l_id = blockIdx.y;
228
+ const int chunk_c_id = blockIdx.z;
229
+ const int tid = threadIdx.x;
230
+ const int l_idx = tid / kNThreadsPerC;
231
+ const int c_idx = tid % kNThreadsPerC;
232
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
233
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.x_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
234
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr)
235
+ + chunk_c_id * kChunkSizeC * params.weight_c_stride;
236
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
237
+ + (chunk_l_id * kChunkSizeL + l_idx) * params.out_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
238
+ int *seq_idx = !kHasSeqIdx ? nullptr : reinterpret_cast<int *>(params.seq_idx_ptr)
239
+ + batch_id * params.seqlen + chunk_l_id * kChunkSizeL;
240
+ input_t *initial_states = params.initial_states_ptr == nullptr || chunk_l_id > 0 ? nullptr
241
+ : reinterpret_cast<input_t *>(params.initial_states_ptr) + batch_id * params.initial_states_batch_stride + l_idx * params.initial_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
242
+ // The last L-chunk will also have enough info to write to final states, since it also contains a few x values
243
+ // from the previous L-chunk.
244
+ input_t *final_states = params.final_states_ptr == nullptr || chunk_l_id < gridDim.y - 1 ? nullptr
245
+ : reinterpret_cast<input_t *>(params.final_states_ptr) + batch_id * params.final_states_batch_stride + l_idx * params.final_states_l_stride + chunk_c_id * kChunkSizeC + c_idx * kNElts;
246
+
247
+ #pragma unroll
248
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
249
+ input_t x_vals_load[kNElts] = {0};
250
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
251
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
252
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x + l * kLPerLoad * params.x_l_stride);
253
+ }
254
+ reinterpret_cast<vec_t *>(x_smem[kWidth - 1 + l * kLPerLoad + l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
255
+ }
256
+ // Load the elements from the previous chunk that are needed for convolution.
257
+ if (l_idx < kWidth - 1) {
258
+ input_t x_vals_load[kNElts] = {0};
259
+ if (chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) >= 0
260
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < params.seqlen
261
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
262
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(x - (kWidth - 1) * params.x_l_stride);
263
+ } else if (initial_states != nullptr
264
+ && chunk_l_id * kChunkSizeL + l_idx - (kWidth - 1) < 0
265
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
266
+ reinterpret_cast<vec_t *>(x_vals_load)[0] = *reinterpret_cast<vec_t *>(initial_states);
267
+ }
268
+ reinterpret_cast<vec_t *>(x_smem[l_idx])[c_idx] = reinterpret_cast<vec_t *>(x_vals_load)[0];
269
+ }
270
+
271
+ __syncthreads();
272
+
273
+ if (final_states != nullptr
274
+ && l_idx < kWidth - 1
275
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
276
+ // x_smem[0] contains element at index chunk_l_id * kChunkSizeL - (kWidth - 1)
277
+ // So last few elements (index params.seqlen - kWidth + 1 + l_idx) are stored in x_smem[params.seqlen - kWidth + 1 + l_idx - (chunk_l_id * kChunkSizeL - kWidth + 1)][c_idx]
278
+ *reinterpret_cast<vec_t *>(final_states) = reinterpret_cast<vec_t *>(x_smem[params.seqlen + l_idx - chunk_l_id * kChunkSizeL])[c_idx];
279
+ }
280
+
281
+ constexpr int kLPerThread = constexpr_min(kChunkSizeL * kChunkSizeC / kNThreads, kChunkSizeL);
282
+ static_assert(kLPerThread * kNThreads == kChunkSizeL * kChunkSizeC);
283
+ constexpr int kNThreadsPerRow = kChunkSizeL / kLPerThread;
284
+ static_assert(kNThreadsPerRow * kLPerThread == kChunkSizeL);
285
+ // kChunkSizeL, kLPerThread, kNThreadsPerRow should be powers of 2 for simplicity
286
+ static_assert((kChunkSizeL & (kChunkSizeL - 1)) == 0);
287
+ static_assert((kLPerThread & (kLPerThread - 1)) == 0);
288
+ static_assert((kNThreadsPerRow & (kNThreadsPerRow - 1)) == 0);
289
+ static_assert(kNThreadsPerRow <= 32);
290
+
291
+ const int row_idx = tid / kNThreadsPerRow;
292
+ const int col_idx = tid % kNThreadsPerRow;
293
+
294
+ float bias_val = params.bias_ptr == nullptr || chunk_c_id * kChunkSizeC + row_idx >= params.dim ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[chunk_c_id * kChunkSizeC + row_idx]);
295
+ float weight_vals[kWidth] = {0};
296
+ if (chunk_c_id * kChunkSizeC + row_idx < params.dim) {
297
+ #pragma unroll
298
+ for (int w = 0; w < kWidth; ++w) {
299
+ weight_vals[w] = weight[row_idx * params.weight_c_stride + w * params.weight_width_stride];
300
+ }
301
+ }
302
+ float x_vals[kWidth - 1 + kLPerThread];
303
+ #pragma unroll
304
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
305
+ x_vals[i] = float(x_smem[col_idx * kLPerThread + i][row_idx]);
306
+ }
307
+ int seq_idx_thread[kWidth - 1 + kLPerThread];
308
+ if constexpr (kHasSeqIdx) {
309
+ #pragma unroll
310
+ for (int i = 0; i < kWidth - 1 + kLPerThread; ++i) {
311
+ seq_idx_thread[i] = chunk_l_id * kChunkSizeL + col_idx * kLPerThread + i - (kWidth - 1) >= 0 ? seq_idx[col_idx * kLPerThread + i - (kWidth - 1)] : -1;
312
+ }
313
+ }
314
+
315
+ float out_vals[kLPerThread];
316
+ #pragma unroll
317
+ for (int i = 0; i < kLPerThread; ++i) {
318
+ out_vals[i] = bias_val;
319
+ const int seq_idx_cur = !kHasSeqIdx ? 0 : seq_idx_thread[i + kWidth - 1];
320
+ #pragma unroll
321
+ for (int w = 0; w < kWidth; ++w) {
322
+ if constexpr (!kHasSeqIdx) {
323
+ out_vals[i] += weight_vals[w] * x_vals[i + w];
324
+ } else {
325
+ out_vals[i] += seq_idx_thread[i + w] == seq_idx_cur ? weight_vals[w] * x_vals[i + w] : 0.f;
326
+ }
327
+ }
328
+ if (params.silu_activation) {out_vals[i] = out_vals[i] / (1 + expf(-out_vals[i])); }
329
+ }
330
+
331
+ __syncthreads();
332
+ #pragma unroll
333
+ for (int i = 0; i < kLPerThread; ++i) { x_smem[col_idx * kLPerThread + i][row_idx] = out_vals[i]; }
334
+ __syncthreads();
335
+
336
+ #pragma unroll
337
+ for (int l = 0; l < Ktraits::kNLoads; ++l) {
338
+ input_t out_vals_store[kNElts];
339
+ reinterpret_cast<vec_t *>(out_vals_store)[0] = reinterpret_cast<vec_t *>(x_smem[l * kLPerLoad + l_idx])[c_idx];
340
+ if (chunk_l_id * kChunkSizeL + l * kLPerLoad + l_idx < params.seqlen
341
+ && chunk_c_id * kChunkSizeC + c_idx * kNElts < params.dim) {
342
+ *reinterpret_cast<vec_t *>(out + l * kLPerLoad * params.out_l_stride) = reinterpret_cast<vec_t *>(out_vals_store)[0];
343
+ }
344
+ }
345
+
346
+ }
347
+
348
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
349
+ void causal_conv1d_channellast_fwd_launch(ConvParamsBase &params, cudaStream_t stream) {
350
+ BOOL_SWITCH(params.seq_idx_ptr != nullptr, kHasSeqIdx, [&] {
351
+ using Ktraits = Causal_conv1d_channellast_fwd_kernel_traits<kNThreads, kWidth, 64, true, input_t, weight_t>;
352
+ // constexpr int kSmemSize = Ktraits::kSmemSize;
353
+ constexpr int kChunkSizeL = Ktraits::kChunkSizeL;
354
+ constexpr int kChunkSizeC = Ktraits::kNEltsPerRow;
355
+ const int n_chunks_L = (params.seqlen + kChunkSizeL - 1) / kChunkSizeL;
356
+ const int n_chunks_C = (params.dim + kChunkSizeC - 1) / kChunkSizeC;
357
+ dim3 grid(params.batch, n_chunks_L, n_chunks_C);
358
+ dim3 block(Ktraits::kNThreads);
359
+ auto kernel = &causal_conv1d_channellast_fwd_kernel<Ktraits, kHasSeqIdx>;
360
+ // if (kSmemSize >= 48 * 1024) {
361
+ // C10_CUDA_CHECK(cudaFuncSetAttribute(
362
+ // kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, kSmemSize));
363
+ // }
364
+ // kernel<<<grid, Ktraits::kNThreads, kSmemSize, stream>>>(params);
365
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
366
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
367
+ });
368
+ }
369
+
370
+ template<typename input_t, typename weight_t>
371
+ void causal_conv1d_channellast_fwd_cuda(ConvParamsBase &params, cudaStream_t stream) {
372
+ if (params.width == 2) {
373
+ causal_conv1d_channellast_fwd_launch<128, 2, input_t, weight_t>(params, stream);
374
+ } else if (params.width == 3) {
375
+ causal_conv1d_channellast_fwd_launch<128, 3, input_t, weight_t>(params, stream);
376
+ } else if (params.width == 4) {
377
+ causal_conv1d_channellast_fwd_launch<128, 4, input_t, weight_t>(params, stream);
378
+ }
379
+ }
380
+
381
+ template void causal_conv1d_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
382
+ template void causal_conv1d_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
383
+ template void causal_conv1d_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
384
+ template void causal_conv1d_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
385
+ template void causal_conv1d_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
386
+ template void causal_conv1d_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
387
+ template void causal_conv1d_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
388
+ template void causal_conv1d_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
389
+ template void causal_conv1d_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
390
+
391
+ template void causal_conv1d_channellast_fwd_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
392
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
393
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
394
+ template void causal_conv1d_channellast_fwd_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
395
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
396
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
397
+ template void causal_conv1d_channellast_fwd_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
398
+ template void causal_conv1d_channellast_fwd_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
399
+ template void causal_conv1d_channellast_fwd_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
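Both forward kernels compute a depthwise causal convolution with an optional fused SiLU. A rough PyTorch sketch of the same math (mirroring the reference path rather than the kernel itself; the helper name is hypothetical):

import torch
import torch.nn.functional as F

def causal_conv1d_sketch(x, weight, bias=None, activation=None):
    # x: (batch, dim, seqlen), weight: (dim, width), bias: (dim,) or None
    dim, width = weight.shape
    w = weight.unsqueeze(1).to(x.dtype)                     # (dim, 1, width) depthwise filters
    b = bias.to(x.dtype) if bias is not None else None
    # Left padding of width - 1 makes position t depend only on x[..., :t + 1].
    out = F.conv1d(x, w, b, padding=width - 1, groups=dim)[..., : x.shape[-1]]
    return F.silu(out) if activation == "silu" else out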
causal-conv1d/causal_conv1d_update.cu ADDED
@@ -0,0 +1,137 @@
1
+ /******************************************************************************
2
+ * Copyright (c) 2023, Tri Dao.
3
+ ******************************************************************************/
4
+
5
+ #include <c10/util/BFloat16.h>
6
+ #include <c10/util/Half.h>
7
+ #include <c10/cuda/CUDAException.h> // For C10_CUDA_CHECK and C10_CUDA_KERNEL_LAUNCH_CHECK
8
+
9
+ #include "causal_conv1d.h"
10
+ #include "causal_conv1d_common.h"
11
+ #include "static_switch.h"
12
+
13
+ template<int kNThreads_, int kWidth_, typename input_t_, typename weight_t_>
14
+ struct Causal_conv1d_update_kernel_traits {
15
+ using input_t = input_t_;
16
+ using weight_t = weight_t_;
17
+ static constexpr int kNThreads = kNThreads_;
18
+ static constexpr int kWidth = kWidth_;
19
+ static constexpr int kNBytes = sizeof(input_t);
20
+ static_assert(kNBytes == 2 || kNBytes == 4);
21
+ };
22
+
23
+ template<typename Ktraits, bool kIsCircularBuffer>
24
+ __global__ __launch_bounds__(Ktraits::kNThreads)
25
+ void causal_conv1d_update_kernel(ConvParamsBase params) {
26
+ constexpr int kWidth = Ktraits::kWidth;
27
+ constexpr int kNThreads = Ktraits::kNThreads;
28
+ using input_t = typename Ktraits::input_t;
29
+ using weight_t = typename Ktraits::weight_t;
30
+
31
+ const int tidx = threadIdx.x;
32
+ const int batch_id = blockIdx.x;
33
+ const int channel_id = blockIdx.y * kNThreads + tidx;
34
+ if (channel_id >= params.dim) return;
35
+
36
+ input_t *x = reinterpret_cast<input_t *>(params.x_ptr) + batch_id * params.x_batch_stride
37
+ + channel_id * params.x_c_stride;
38
+
39
+ // If params.conv_state_indices_ptr is set, then the conv state is gathered from the conv state tensor
40
+ // along the batch axis. Otherwise, the conv state coordinate is the same as the batch id.
41
+ const int conv_state_batch_coord = params.conv_state_indices_ptr == nullptr
42
+ ? batch_id
43
+ : params.conv_state_indices_ptr[batch_id];
44
+ input_t *conv_state = reinterpret_cast<input_t *>(params.conv_state_ptr)
45
+ + conv_state_batch_coord * params.conv_state_batch_stride
46
+ + channel_id * params.conv_state_c_stride;
47
+ weight_t *weight = reinterpret_cast<weight_t *>(params.weight_ptr) + channel_id * params.weight_c_stride;
48
+ input_t *out = reinterpret_cast<input_t *>(params.out_ptr) + batch_id * params.out_batch_stride
49
+ + channel_id * params.out_c_stride;
50
+ float bias_val = params.bias_ptr == nullptr ? 0.f : float(reinterpret_cast<weight_t *>(params.bias_ptr)[channel_id]);
51
+
52
+ int state_len = params.conv_state_len;
53
+ int advance_len = params.seqlen;
54
+ int cache_seqlen = kIsCircularBuffer ? params.cache_seqlens[batch_id] % state_len : 0;
55
+ int update_idx = cache_seqlen - (kWidth - 1);
56
+ update_idx = update_idx < 0 ? update_idx + state_len : update_idx;
57
+
58
+ float weight_vals[kWidth] = {0};
59
+ #pragma unroll
60
+ for (int i = 0; i < kWidth; ++i) { weight_vals[i] = float(weight[i * params.weight_width_stride]); }
61
+
62
+ float x_vals[kWidth] = {0};
63
+ if constexpr (!kIsCircularBuffer) {
64
+ #pragma unroll 2
65
+ for (int i = 0; i < state_len - advance_len - (kWidth - 1); ++i) {
66
+ conv_state[i * params.conv_state_l_stride] = conv_state[(i + advance_len) * params.conv_state_l_stride];
67
+ }
68
+ #pragma unroll
69
+ for (int i = 0; i < kWidth - 1; ++i) {
70
+ input_t state_val = conv_state[(state_len - (kWidth - 1) + i) * params.conv_state_l_stride];
71
+ if (i < advance_len + (kWidth - 1) && state_len - advance_len - (kWidth - 1) + i >= 0) {
72
+ conv_state[(state_len - advance_len - (kWidth - 1) + i) * params.conv_state_l_stride] = state_val;
73
+ }
74
+ x_vals[i] = float(state_val);
75
+ }
76
+ } else {
77
+ #pragma unroll
78
+ for (int i = 0; i < kWidth - 1; ++i, update_idx = update_idx + 1 >= state_len ? update_idx + 1 - state_len : update_idx + 1) {
79
+ input_t state_val = conv_state[update_idx * params.conv_state_l_stride];
80
+ x_vals[i] = float(state_val);
81
+ }
82
+ }
83
+ #pragma unroll 2
84
+ for (int i = 0; i < params.seqlen; ++i) {
85
+ input_t x_val = x[i * params.x_l_stride];
86
+ if constexpr (!kIsCircularBuffer) {
87
+ if (i < advance_len && state_len - advance_len + i >= 0) {
88
+ conv_state[(state_len - advance_len + i) * params.conv_state_l_stride] = x_val;
89
+ }
90
+ } else {
91
+ conv_state[update_idx * params.conv_state_l_stride] = x_val;
92
+ ++update_idx;
93
+ update_idx = update_idx >= state_len ? update_idx - state_len : update_idx;
94
+ }
95
+ x_vals[kWidth - 1] = float(x_val);
96
+ float out_val = bias_val;
97
+ #pragma unroll
98
+ for (int j = 0; j < kWidth; ++j) { out_val += weight_vals[j] * x_vals[j]; }
99
+ if (params.silu_activation) { out_val = out_val / (1 + expf(-out_val)); }
100
+ out[i * params.out_l_stride] = input_t(out_val);
101
+ // Shift the input buffer by 1
102
+ #pragma unroll
103
+ for (int i = 0; i < kWidth - 1; ++i) { x_vals[i] = x_vals[i + 1]; }
104
+ }
105
+ }
106
+
107
+ template<int kNThreads, int kWidth, typename input_t, typename weight_t>
108
+ void causal_conv1d_update_launch(ConvParamsBase &params, cudaStream_t stream) {
109
+ using Ktraits = Causal_conv1d_update_kernel_traits<kNThreads, kWidth, input_t, weight_t>;
110
+ dim3 grid(params.batch, (params.dim + kNThreads - 1) / kNThreads);
111
+ auto kernel = params.cache_seqlens == nullptr
112
+ ? &causal_conv1d_update_kernel<Ktraits, false>
113
+ : &causal_conv1d_update_kernel<Ktraits, true>;
114
+ kernel<<<grid, Ktraits::kNThreads, 0, stream>>>(params);
115
+ C10_CUDA_KERNEL_LAUNCH_CHECK();
116
+ }
117
+
118
+ template<typename input_t, typename weight_t>
119
+ void causal_conv1d_update_cuda(ConvParamsBase &params, cudaStream_t stream) {
120
+ if (params.width == 2) {
121
+ causal_conv1d_update_launch<64, 2, input_t, weight_t>(params, stream);
122
+ } else if (params.width == 3) {
123
+ causal_conv1d_update_launch<64, 3, input_t, weight_t>(params, stream);
124
+ } else if (params.width == 4) {
125
+ causal_conv1d_update_launch<64, 4, input_t, weight_t>(params, stream);
126
+ }
127
+ }
128
+
129
+ template void causal_conv1d_update_cuda<float, float>(ConvParamsBase &params, cudaStream_t stream);
130
+ template void causal_conv1d_update_cuda<at::Half, float>(ConvParamsBase &params, cudaStream_t stream);
131
+ template void causal_conv1d_update_cuda<at::BFloat16, float>(ConvParamsBase &params, cudaStream_t stream);
132
+ template void causal_conv1d_update_cuda<float, at::Half>(ConvParamsBase &params, cudaStream_t stream);
133
+ template void causal_conv1d_update_cuda<at::Half, at::Half>(ConvParamsBase &params, cudaStream_t stream);
134
+ template void causal_conv1d_update_cuda<at::BFloat16, at::Half>(ConvParamsBase &params, cudaStream_t stream);
135
+ template void causal_conv1d_update_cuda<float, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
136
+ template void causal_conv1d_update_cuda<at::Half, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
137
+ template void causal_conv1d_update_cuda<at::BFloat16, at::BFloat16>(ConvParamsBase &params, cudaStream_t stream);
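The update kernel advances a rolling per-channel state by seqlen new tokens. A plain PyTorch sketch of the non-circular-buffer path (illustrative only; the real entry point is causal_conv1d_update re-exported by this package, and the helper name is hypothetical):

import torch
import torch.nn.functional as F

def conv1d_update_sketch(x, conv_state, weight, bias=None, activation=None):
    # x: (batch, dim, seqlen), conv_state: (batch, dim, state_len >= width - 1), weight: (dim, width)
    dim, width = weight.shape
    x_full = torch.cat([conv_state.to(x.dtype), x], dim=-1)
    conv_state.copy_(x_full[..., -conv_state.shape[-1]:])   # shift left by seqlen, append x (in place)
    out = F.conv1d(x_full, weight.unsqueeze(1).to(x.dtype),
                   bias.to(x.dtype) if bias is not None else None,
                   groups=dim)[..., -x.shape[-1]:]          # one causal window per new token
    return F.silu(out) if activation == "silu" else out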
causal-conv1d/static_switch.h ADDED
@@ -0,0 +1,25 @@
1
+ // Inspired by https://github.com/NVIDIA/DALI/blob/main/include/dali/core/static_switch.h
2
+ // and https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/Dispatch.h
3
+
4
+ #pragma once
5
+
6
+ /// @param COND - a boolean expression to switch by
7
+ /// @param CONST_NAME - a name given for the constexpr bool variable.
8
+ /// @param ... - code to execute for true and false
9
+ ///
10
+ /// Usage:
11
+ /// ```
12
+ /// BOOL_SWITCH(flag, BoolConst, [&] {
13
+ /// some_function<BoolConst>(...);
14
+ /// });
15
+ /// ```
16
+ #define BOOL_SWITCH(COND, CONST_NAME, ...) \
17
+ [&] { \
18
+ if (COND) { \
19
+ static constexpr bool CONST_NAME = true; \
20
+ return __VA_ARGS__(); \
21
+ } else { \
22
+ static constexpr bool CONST_NAME = false; \
23
+ return __VA_ARGS__(); \
24
+ } \
25
+ }()
flake.lock ADDED
@@ -0,0 +1,168 @@
1
+ {
2
+ "nodes": {
3
+ "flake-compat": {
4
+ "locked": {
5
+ "lastModified": 1747046372,
6
+ "narHash": "sha256-CIVLLkVgvHYbgI2UpXvIIBJ12HWgX+fjA8Xf8PUmqCY=",
7
+ "owner": "edolstra",
8
+ "repo": "flake-compat",
9
+ "rev": "9100a0f413b0c601e0533d1d94ffd501ce2e7885",
10
+ "type": "github"
11
+ },
12
+ "original": {
13
+ "owner": "edolstra",
14
+ "repo": "flake-compat",
15
+ "type": "github"
16
+ }
17
+ },
18
+ "flake-compat_2": {
19
+ "locked": {
20
+ "lastModified": 1733328505,
21
+ "narHash": "sha256-NeCCThCEP3eCl2l/+27kNNK7QrwZB1IJCrXfrbv5oqU=",
22
+ "owner": "edolstra",
23
+ "repo": "flake-compat",
24
+ "rev": "ff81ac966bb2cae68946d5ed5fc4994f96d0ffec",
25
+ "type": "github"
26
+ },
27
+ "original": {
28
+ "owner": "edolstra",
29
+ "repo": "flake-compat",
30
+ "type": "github"
31
+ }
32
+ },
33
+ "flake-utils": {
34
+ "inputs": {
35
+ "systems": "systems"
36
+ },
37
+ "locked": {
38
+ "lastModified": 1731533236,
39
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
40
+ "owner": "numtide",
41
+ "repo": "flake-utils",
42
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
43
+ "type": "github"
44
+ },
45
+ "original": {
46
+ "owner": "numtide",
47
+ "repo": "flake-utils",
48
+ "type": "github"
49
+ }
50
+ },
51
+ "flake-utils_2": {
52
+ "inputs": {
53
+ "systems": "systems_2"
54
+ },
55
+ "locked": {
56
+ "lastModified": 1731533236,
57
+ "narHash": "sha256-l0KFg5HjrsfsO/JpG+r7fRrqm12kzFHyUHqHCVpMMbI=",
58
+ "owner": "numtide",
59
+ "repo": "flake-utils",
60
+ "rev": "11707dc2f618dd54ca8739b309ec4fc024de578b",
61
+ "type": "github"
62
+ },
63
+ "original": {
64
+ "owner": "numtide",
65
+ "repo": "flake-utils",
66
+ "type": "github"
67
+ }
68
+ },
69
+ "hf-nix": {
70
+ "inputs": {
71
+ "flake-compat": "flake-compat_2",
72
+ "flake-utils": "flake-utils_2",
73
+ "nixpkgs": "nixpkgs"
74
+ },
75
+ "locked": {
76
+ "lastModified": 1754038838,
77
+ "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
78
+ "owner": "huggingface",
79
+ "repo": "hf-nix",
80
+ "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
81
+ "type": "github"
82
+ },
83
+ "original": {
84
+ "owner": "huggingface",
85
+ "repo": "hf-nix",
86
+ "type": "github"
87
+ }
88
+ },
89
+ "kernel-builder": {
90
+ "inputs": {
91
+ "flake-compat": "flake-compat",
92
+ "flake-utils": "flake-utils",
93
+ "hf-nix": "hf-nix",
94
+ "nixpkgs": [
95
+ "kernel-builder",
96
+ "hf-nix",
97
+ "nixpkgs"
98
+ ]
99
+ },
100
+ "locked": {
101
+ "lastModified": 1755181472,
102
+ "narHash": "sha256-xOXjhehC5xi/XB4fXZ5c0L2sSyDjJQdlH7/BcdHLBaM=",
103
+ "owner": "huggingface",
104
+ "repo": "kernel-builder",
105
+ "rev": "85da46f660c1c43b40771c3df3b223bb3fa39bec",
106
+ "type": "github"
107
+ },
108
+ "original": {
109
+ "owner": "huggingface",
110
+ "repo": "kernel-builder",
111
+ "type": "github"
112
+ }
113
+ },
114
+ "nixpkgs": {
115
+ "locked": {
116
+ "lastModified": 1752785354,
117
+ "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
118
+ "owner": "nixos",
119
+ "repo": "nixpkgs",
120
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
121
+ "type": "github"
122
+ },
123
+ "original": {
124
+ "owner": "nixos",
125
+ "repo": "nixpkgs",
126
+ "rev": "d38025438a6ee456758dc03188ca6873a415463b",
127
+ "type": "github"
128
+ }
129
+ },
130
+ "root": {
131
+ "inputs": {
132
+ "kernel-builder": "kernel-builder"
133
+ }
134
+ },
135
+ "systems": {
136
+ "locked": {
137
+ "lastModified": 1681028828,
138
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
139
+ "owner": "nix-systems",
140
+ "repo": "default",
141
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
142
+ "type": "github"
143
+ },
144
+ "original": {
145
+ "owner": "nix-systems",
146
+ "repo": "default",
147
+ "type": "github"
148
+ }
149
+ },
150
+ "systems_2": {
151
+ "locked": {
152
+ "lastModified": 1681028828,
153
+ "narHash": "sha256-Vy1rq5AaRuLzOxct8nz4T6wlgyUR7zLU309k9mBC768=",
154
+ "owner": "nix-systems",
155
+ "repo": "default",
156
+ "rev": "da67096a3b9bf56a91d16901293e51ba5b49a27e",
157
+ "type": "github"
158
+ },
159
+ "original": {
160
+ "owner": "nix-systems",
161
+ "repo": "default",
162
+ "type": "github"
163
+ }
164
+ }
165
+ },
166
+ "root": "root",
167
+ "version": 7
168
+ }
flake.nix ADDED
@@ -0,0 +1,18 @@
1
+ {
2
+ description = "Flake for attention kernels";
3
+
4
+ inputs = {
5
+ kernel-builder.url = "github:huggingface/kernel-builder";
6
+ };
7
+
8
+ outputs =
9
+ {
10
+ self,
11
+ kernel-builder,
12
+ }:
13
+ kernel-builder.lib.genFlakeOutputs {
14
+ path = ./.;
15
+ rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
16
+ pythonCheckInputs = pkgs: with pkgs; [ einops ];
17
+ };
18
+ }
tests/test_causal_conv1d.py ADDED
@@ -0,0 +1,353 @@
1
+ # Copyright (C) 2024, Tri Dao.
2
+
3
+ import math
4
+
5
+ import torch
6
+ import torch.nn.functional as F
7
+
8
+ import pytest
9
+
10
+ from einops import rearrange
11
+
12
+
13
+ from causal_conv1d import causal_conv1d_fn, causal_conv1d_update, causal_conv1d_varlen_states
14
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_ref
15
+ from causal_conv1d.causal_conv1d_interface import causal_conv1d_update_ref
16
+ from causal_conv1d.causal_conv1d_varlen import causal_conv1d_varlen_states_ref
17
+
18
+
19
+ @pytest.mark.parametrize("return_final_states", [False, True])
20
+ # @pytest.mark.parametrize("return_final_states", [True])
21
+ @pytest.mark.parametrize("has_initial_states", [False, True])
22
+ # @pytest.mark.parametrize("has_initial_states", [False])
23
+ @pytest.mark.parametrize("channel_last", [False, True])
24
+ # @pytest.mark.parametrize('channel_last', [True])
25
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
26
+ # @pytest.mark.parametrize('itype', [torch.float16])
27
+ @pytest.mark.parametrize("silu_activation", [False, True])
28
+ # @pytest.mark.parametrize('silu_activation', [True])
29
+ @pytest.mark.parametrize("has_bias", [False, True])
30
+ # @pytest.mark.parametrize('has_bias', [True])
31
+ @pytest.mark.parametrize("width", [2, 3, 4])
32
+ # @pytest.mark.parametrize('width', [3])
33
+ @pytest.mark.parametrize(
34
+ "seqlen", [1, 2, 8, 16, 32, 64, 128, 129, 130, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
35
+ )
36
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
37
+ # @pytest.mark.parametrize('seqlen', [128])
38
+ @pytest.mark.parametrize('dim', [64, 4096 + 32])
39
+ # @pytest.mark.parametrize('dim', [64])
40
+ def test_causal_conv1d(dim, seqlen, width, has_bias, silu_activation, itype, channel_last, has_initial_states, return_final_states):
41
+ if not channel_last and (has_initial_states or return_final_states):
42
+ pytest.skip("Only channel_last support initial_states or return_final_states")
43
+ device = "cuda"
44
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
45
+ if itype == torch.bfloat16:
46
+ rtol, atol = 1e-2, 5e-2
47
+ rtolw, atolw = (1e-3, 1e-3)
48
+ # set seed
49
+ torch.random.manual_seed(0)
50
+ batch = 2
51
+ # batch = 1
52
+ if not channel_last:
53
+ x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
54
+ else:
55
+ x = rearrange(
56
+ torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
57
+ ).requires_grad_()
58
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
59
+ if has_bias:
60
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
61
+ else:
62
+ bias = None
63
+ if has_initial_states:
64
+ initial_states = torch.randn(batch, width - 1, dim, device=device, dtype=itype).transpose(1, 2).requires_grad_()
65
+ else:
66
+ initial_states = None
67
+ x_ref = x.detach().clone().requires_grad_()
68
+ weight_ref = weight.detach().clone().requires_grad_()
69
+ bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
70
+ initial_states_ref = initial_states.detach().clone().requires_grad_() if initial_states is not None else None
71
+ activation = None if not silu_activation else "silu"
72
+ out = causal_conv1d_fn(x, weight, bias, initial_states=initial_states, return_final_states=return_final_states,
73
+ activation=activation)
74
+ out_ref = causal_conv1d_ref(x_ref, weight_ref, bias_ref, initial_states=initial_states_ref, return_final_states=return_final_states, activation=activation)
75
+ if return_final_states:
76
+ out, final_states = out
77
+ out_ref, final_states_ref = out_ref
78
+ print(f"Final states max diff: {(final_states - final_states_ref).abs().max().item()}")
79
+ print(f"Final states mean diff: {(final_states - final_states_ref).abs().mean().item()}")
80
+ assert torch.allclose(final_states, final_states_ref, rtol=rtol, atol=atol)
81
+
82
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
83
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
84
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
85
+
86
+ if return_final_states:
87
+ out += F.sigmoid(final_states).sum(dim=-1, keepdim=True)
88
+ out_ref += F.sigmoid(final_states_ref).sum(dim=-1, keepdim=True)
89
+
90
+ g = torch.randn_like(out)
91
+ out.backward(g)
92
+ out_ref.backward(g)
93
+
94
+ print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
95
+ print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
96
+ if has_bias:
97
+ print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
98
+ if has_initial_states:
99
+ print(f"dinitial_states max diff: {(initial_states.grad - initial_states_ref.grad).abs().max().item()}")
100
+
101
+ assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
102
+ assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
103
+ if has_bias:
104
+ assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
105
+ if has_initial_states:
106
+ assert torch.allclose(initial_states.grad, initial_states_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
107
+
108
+
109
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
110
+ # @pytest.mark.parametrize('itype', [torch.float16])
111
+ @pytest.mark.parametrize("silu_activation", [False, True])
112
+ # @pytest.mark.parametrize('silu_activation', [True])
113
+ @pytest.mark.parametrize("has_bias", [False, True])
114
+ # @pytest.mark.parametrize('has_bias', [True])
115
+ @pytest.mark.parametrize("has_cache_seqlens", [False, True])
116
+ # @pytest.mark.parametrize('has_cache_seqlens', [True])
117
+ @pytest.mark.parametrize("seqlen", [1, 4, 5])
118
+ # @pytest.mark.parametrize('seqlen', [4])
119
+ @pytest.mark.parametrize("width", [2, 3, 4])
120
+ # @pytest.mark.parametrize('width', [4])
121
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
122
+ # @pytest.mark.parametrize("dim", [2048])
123
+ def test_causal_conv1d_update(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
124
+ device = "cuda"
125
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
126
+ if itype == torch.bfloat16:
127
+ rtol, atol = 1e-2, 5e-2
128
+ rtolw, atolw = (1e-3, 1e-3)
129
+ # set seed
130
+ torch.random.manual_seed(0)
131
+ batch = 64
132
+ # batch = 1
133
+ # dim = 64
134
+ x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
135
+ state_len = torch.randint(width - 1, width + 10, (1,)).item()
136
+ conv_state = torch.randn(batch, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
137
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
138
+ if has_bias:
139
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
140
+ else:
141
+ bias = None
142
+ conv_state_ref = conv_state.detach().clone()
143
+ activation = None if not silu_activation else "silu"
144
+ cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
145
+ if has_cache_seqlens else None)
146
+ out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
147
+ out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
148
+
149
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
150
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
151
+ assert torch.equal(conv_state, conv_state_ref)
152
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
153
+
154
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
155
+ # @pytest.mark.parametrize('itype', [torch.float16])
156
+ @pytest.mark.parametrize("silu_activation", [False, True])
157
+ # @pytest.mark.parametrize('silu_activation', [True])
158
+ @pytest.mark.parametrize("has_bias", [False, True])
159
+ # @pytest.mark.parametrize('has_bias', [True])
160
+ @pytest.mark.parametrize("has_cache_seqlens", [False, True])
161
+ # @pytest.mark.parametrize('has_cache_seqlens', [True])
162
+ @pytest.mark.parametrize("seqlen", [1, 4, 5])
163
+ # @pytest.mark.parametrize('seqlen', [4])
164
+ @pytest.mark.parametrize("width", [2, 3, 4])
165
+ # @pytest.mark.parametrize('width', [4])
166
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
167
+ # @pytest.mark.parametrize("dim", [2048])
168
+ def test_causal_conv1d_update_with_batch_gather(dim, width, seqlen, has_cache_seqlens, has_bias, silu_activation, itype):
169
+ device = "cuda"
170
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
171
+ if itype == torch.bfloat16:
172
+ rtol, atol = 1e-2, 5e-2
173
+ rtolw, atolw = (1e-3, 1e-3)
174
+ # set seed
175
+ torch.random.manual_seed(0)
176
+ batch = 64
177
+ # batch = 1
178
+ # dim = 64
179
+ x = torch.randn(batch, seqlen, dim, device=device, dtype=itype).transpose(-1, -2)
180
+ state_len = torch.randint(width - 1, width + 10, (1,)).item()
181
+
182
+ total_entries = 10 * batch
183
+ conv_state = torch.randn(total_entries, state_len, dim, device=device, dtype=itype).transpose(-1, -2)
184
+ conv_state_indices = torch.randperm(total_entries)[:batch].to(dtype=torch.int32, device=device)
185
+
186
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
187
+ if has_bias:
188
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
189
+ else:
190
+ bias = None
191
+ conv_state_ref = conv_state[conv_state_indices, :].detach().clone()
192
+ activation = None if not silu_activation else "silu"
193
+ cache_seqlens = (torch.randint(0, 1024, (batch,), dtype=torch.int32, device=device)
194
+ if has_cache_seqlens else None)
195
+ out = causal_conv1d_update(x, conv_state, weight, bias, activation=activation,
196
+ cache_seqlens=cache_seqlens, conv_state_indices=conv_state_indices)
197
+ out_ref = causal_conv1d_update_ref(x, conv_state_ref, weight, bias, activation=activation, cache_seqlens=cache_seqlens)
198
+
199
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
200
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
201
+ assert torch.equal(conv_state[conv_state_indices, :], conv_state_ref)
202
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
203
+
204
+
205
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
206
+ # @pytest.mark.parametrize('itype', [torch.float16])
207
+ @pytest.mark.parametrize("dim", [2048, 2048 + 16, 4096])
208
+ # @pytest.mark.parametrize("dim", [2048])
209
+ def test_causal_conv1d_get_states(dim, itype):
210
+ device = "cuda"
211
+ # set seed
212
+ torch.random.manual_seed(0)
213
+ seqlens = torch.randint(1, 32, (100,), device=device)
214
+ total_seqlen = seqlens.sum().item()
215
+ x = torch.randn(total_seqlen, dim, device=device, dtype=itype)
216
+ cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))
217
+ state_len = 20
218
+ out = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
219
+ out_ref = causal_conv1d_varlen_states_ref(x, cu_seqlens, state_len)
220
+ assert torch.equal(out, out_ref)
221
+
222
+
223
+ # @pytest.mark.parametrize("channel_last", [False, True])
224
+ @pytest.mark.parametrize('channel_last', [True])
225
+ # @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
226
+ @pytest.mark.parametrize('itype', [torch.bfloat16])
227
+ # @pytest.mark.parametrize("silu_activation", [False, True])
228
+ @pytest.mark.parametrize('silu_activation', [True])
229
+ # @pytest.mark.parametrize("has_bias", [False, True])
230
+ @pytest.mark.parametrize('has_bias', [True])
231
+ # @pytest.mark.parametrize("width", [2, 3, 4])
232
+ @pytest.mark.parametrize('width', [4])
233
+ @pytest.mark.parametrize(
234
+ # "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
235
+ "seqlen", [2048]
236
+ )
237
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
238
+ # @pytest.mark.parametrize('seqlen', [128])
239
+ def test_causal_conv1d_race_condition(seqlen, width, has_bias, silu_activation, itype, channel_last):
240
+ device = "cuda"
241
+ # set seed
242
+ torch.random.manual_seed(0)
243
+ batch = 2
244
+ # batch = 1
245
+ dim = 4096 + 32 # Try dim not divisible by 64
246
+ # dim = 64
247
+ if not channel_last:
248
+ x = torch.randn(batch, 4096 + dim + 64, seqlen, device=device, dtype=itype)[:, 4096:4096 + dim, :].requires_grad_()
249
+ else:
250
+ x = rearrange(
251
+ torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
252
+ ).requires_grad_()
253
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
254
+ if has_bias:
255
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
256
+ else:
257
+ bias = None
258
+ activation = None if not silu_activation else "silu"
259
+ out0 = causal_conv1d_fn(x, weight, bias, activation=activation)
260
+ g = torch.randn_like(out0)
261
+ dx0, dw0, db0 = torch.autograd.grad(out0, (x, weight, bias), g)
262
+ dw_atol = 1e-4
263
+ db_atol = 1e-4
264
+
265
+ for i in range(10000):
266
+ out = causal_conv1d_fn(x, weight, bias, activation=activation)
267
+ dx, dw, db = torch.autograd.grad(out, (x, weight, bias), g)
268
+ dw_equal = torch.allclose(dw, dw0, atol=dw_atol)
269
+ # if not dw_equal:
270
+ # breakpoint()
271
+ if has_bias:
272
+ db_equal = torch.allclose(db, db0, atol=db_atol)
273
+ # if not db_equal:
274
+ # breakpoint()
275
+ assert torch.equal(out, out0)
276
+ assert torch.equal(dx, dx0)
277
+ assert dw_equal
278
+ if has_bias:
279
+ assert db_equal
280
+
281
+
282
+ @pytest.mark.parametrize("itype", [torch.float32, torch.float16, torch.bfloat16])
283
+ # @pytest.mark.parametrize('itype', [torch.float16])
284
+ @pytest.mark.parametrize("silu_activation", [False, True])
285
+ # @pytest.mark.parametrize('silu_activation', [False])
286
+ @pytest.mark.parametrize("has_bias", [False, True])
287
+ # @pytest.mark.parametrize('has_bias', [False])
288
+ @pytest.mark.parametrize("width", [2, 3, 4])
289
+ # @pytest.mark.parametrize('width', [2])
290
+ @pytest.mark.parametrize(
291
+ "seqlen", [8, 16, 32, 64, 128, 151, 256, 372, 512, 784, 1024, 1134, 2048, 4096]
292
+ )
293
+ # @pytest.mark.parametrize('seqlen', [8, 16, 32, 64, 128, 256, 512, 784, 1024, 2048, 4096])
294
+ # @pytest.mark.parametrize('seqlen', [2048])
295
+ @pytest.mark.parametrize('dim', [64, 4096 + 32])
296
+ # @pytest.mark.parametrize('dim', [64])
297
+ def test_causal_conv1d_varlen(dim, seqlen, width, has_bias, silu_activation, itype):
298
+ device = "cuda"
299
+ rtol, atol = (3e-4, 1e-3) if itype == torch.float32 else (3e-3, 5e-3)
300
+ if itype == torch.bfloat16:
301
+ rtol, atol = 1e-2, 5e-2
302
+ rtolw, atolw = (1e-3, 1e-3)
303
+ # set seed
304
+ torch.random.manual_seed(seqlen + dim + width)
305
+ batch = 3
306
+ seqlens = []
307
+ for b in range(batch):
308
+ nsplits = torch.randint(1, 5, (1,)).item()
309
+ eos_pos = torch.randperm(seqlen - 1)[:nsplits].sort().values
310
+ seqlens.append(torch.diff(torch.cat([torch.tensor([-1]), eos_pos, torch.tensor([seqlen - 1])])).tolist())
311
+ assert sum(seqlens[-1]) == seqlen
312
+ assert all(s > 0 for s in seqlens[-1])
313
+ # Only support channel_last
314
+ x = rearrange(
315
+ torch.randn(batch, seqlen, 4096 + dim + 64, device=device, dtype=itype)[:, :, 4096:4096 + dim], "b s d -> b d s"
316
+ ).requires_grad_()
317
+ weight = torch.randn(dim, width, device=device, dtype=torch.float32, requires_grad=True)
318
+ if has_bias:
319
+ bias = torch.randn(dim, device=device, dtype=torch.float32, requires_grad=True)
320
+ else:
321
+ bias = None
322
+ seq_idx = torch.stack([torch.cat([torch.full((s,), i, dtype=torch.int32, device=device) for i, s in enumerate(sl)], dim=0)
323
+ for sl in seqlens], dim=0)
324
+ x_ref = x.detach().clone().requires_grad_()
325
+ weight_ref = weight.detach().clone().requires_grad_()
326
+ bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
327
+ activation = None if not silu_activation else "silu"
328
+ out = causal_conv1d_fn(x, weight, bias, seq_idx=seq_idx, activation=activation)
329
+ out_ref = []
330
+ for b in range(batch):
331
+ out_ref_b = []
332
+ for x_s in torch.split(x_ref[[b]], seqlens[b], dim=2):
333
+ out_ref_b.append(causal_conv1d_ref(x_s, weight_ref, bias_ref, activation=activation))
334
+ out_ref.append(torch.cat(out_ref_b, dim=2))
335
+ out_ref = torch.cat(out_ref, dim=0)
336
+
337
+ print(f"Output max diff: {(out - out_ref).abs().max().item()}")
338
+ print(f"Output mean diff: {(out - out_ref).abs().mean().item()}")
339
+ assert torch.allclose(out, out_ref, rtol=rtol, atol=atol)
340
+
341
+ g = torch.randn_like(out)
342
+ out_ref.backward(g)
343
+ out.backward(g)
344
+
345
+ print(f"dx max diff: {(x.grad - x_ref.grad).abs().max().item()}")
346
+ print(f"dweight max diff: {(weight.grad - weight_ref.grad).abs().max().item()}")
347
+ if has_bias:
348
+ print(f"dbias max diff: {(bias.grad - bias_ref.grad).abs().max().item()}")
349
+
350
+ assert torch.allclose(x.grad, x_ref.grad.to(dtype=itype), rtol=rtol, atol=atol)
351
+ assert torch.allclose(weight.grad, weight_ref.grad, rtol=rtolw, atol=atolw)
352
+ if has_bias:
353
+ assert torch.allclose(bias.grad, bias_ref.grad, rtol=rtolw, atol=atolw)
torch-ext/causal_conv1d/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .causal_conv1d_interface import causal_conv1d_fn, causal_conv1d_update
2
+ from .causal_conv1d_varlen import causal_conv1d_varlen_states
3
+
4
+ __all__ = ["causal_conv1d_fn", "causal_conv1d_update", "causal_conv1d_varlen_states"]
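A minimal usage sketch of the exported API (not part of this commit): it assumes the kernel builds and is importable as the `causal_conv1d` package and that a CUDA device is available; shapes and dtypes are illustrative only.

```python
# Hedged sketch: assumes the built extension is importable as `causal_conv1d`
# and that a CUDA device is present. Shapes/dtypes are arbitrary examples.
import torch
from causal_conv1d import causal_conv1d_fn

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float32)
bias = torch.randn(dim, device="cuda", dtype=torch.float32)

# Depthwise causal conv over the sequence dimension, fused with SiLU.
out = causal_conv1d_fn(x, weight, bias, activation="silu")
print(out.shape)  # torch.Size([2, 64, 128])
```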
torch-ext/causal_conv1d/causal_conv1d_interface.py ADDED
@@ -0,0 +1,242 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ from .cpp_functions import causal_conv1d_fwd_function, causal_conv1d_bwd_function, causal_conv1d_update_function
7
+
8
+
9
+ class CausalConv1dFn(torch.autograd.Function):
10
+ @staticmethod
11
+ def forward(
12
+ ctx,
13
+ x,
14
+ weight,
15
+ bias=None,
16
+ seq_idx=None,
17
+ initial_states=None,
18
+ return_final_states=False,
19
+ final_states_out=None,
20
+ activation=None,
21
+ ):
22
+ if activation not in [None, "silu", "swish"]:
23
+ raise NotImplementedError("activation must be None, silu, or swish")
24
+ if x.stride(2) != 1 and x.stride(1) != 1:
25
+ x = x.contiguous()
26
+ bias = bias.contiguous() if bias is not None else None
27
+ if seq_idx is not None:
28
+ assert (
29
+ initial_states is None
30
+ ), "initial_states must be None if seq_idx is not None"
31
+ assert (
32
+ not return_final_states
33
+ ), "If seq_idx is not None, we don't return final_states_out"
34
+ seq_idx = seq_idx.contiguous() if seq_idx is not None else None
35
+ if initial_states is not None and (
36
+ initial_states.stride(2) != 1 and initial_states.stride(1) != 1
37
+ ):
38
+ initial_states = initial_states.contiguous()
39
+ if return_final_states:
40
+ assert (
41
+ x.stride(1) == 1
42
+ ), "Only channel-last layout support returning final_states_out"
43
+ if final_states_out is not None:
44
+ assert (
45
+ final_states_out.stride(2) == 1 or final_states_out.stride(1) == 1
46
+ )
47
+ else:
48
+ batch, dim, seqlen = x.shape
49
+ width = weight.shape[1]
50
+ final_states_out = torch.empty(
51
+ batch, width - 1, dim, device=x.device, dtype=x.dtype
52
+ ).transpose(1, 2)
53
+ else:
54
+ final_states_out = None
55
+ ctx.activation = activation in ["silu", "swish"]
56
+ out = causal_conv1d_fwd_function(
57
+ x, weight, bias, seq_idx, initial_states, final_states_out, ctx.activation
58
+ )
59
+ ctx.save_for_backward(x, weight, bias, seq_idx, initial_states)
60
+ ctx.return_final_states = return_final_states
61
+ ctx.return_dinitial_states = (
62
+ initial_states is not None and initial_states.requires_grad
63
+ )
64
+ return out if not return_final_states else (out, final_states_out)
65
+
66
+ @staticmethod
67
+ def backward(ctx, dout, *args):
68
+ x, weight, bias, seq_idx, initial_states = ctx.saved_tensors
69
+ dfinal_states = args[0] if ctx.return_final_states else None
70
+ if dout.stride(2) != 1 and dout.stride(1) != 1:
71
+ dout = dout.contiguous()
72
+ # The kernel supports passing in a pre-allocated dx (e.g., in case we want to fuse the
73
+ # backward of conv1d with the backward of chunk).
74
+ # Here we just pass in None and dx will be allocated in the C++ code.
75
+ dx, dweight, dbias, dinitial_states = causal_conv1d_bwd_function(
76
+ x,
77
+ weight,
78
+ bias,
79
+ dout,
80
+ seq_idx,
81
+ initial_states,
82
+ dfinal_states,
83
+ None,
84
+ ctx.return_dinitial_states,
85
+ ctx.activation,
86
+ )
87
+ return (
88
+ dx,
89
+ dweight,
90
+ dbias if bias is not None else None,
91
+ None,
92
+ dinitial_states if initial_states is not None else None,
93
+ None,
94
+ None,
95
+ None,
96
+ )
97
+
98
+
99
+ def causal_conv1d_fn(
100
+ x,
101
+ weight,
102
+ bias=None,
103
+ seq_idx=None,
104
+ initial_states=None,
105
+ return_final_states=False,
106
+ final_states_out=None,
107
+ activation=None,
108
+ ):
109
+ """
110
+ x: (batch, dim, seqlen)
111
+ weight: (dim, width)
112
+ bias: (dim,)
113
+ seq_idx: (batch, seqlen)
114
+ initial_states: (batch, dim, width - 1)
115
+ final_states_out: (batch, dim, width - 1), to be written to
116
+ activation: either None or "silu" or "swish"
117
+
118
+ out: (batch, dim, seqlen)
119
+ """
120
+ return CausalConv1dFn.apply(
121
+ x,
122
+ weight,
123
+ bias,
124
+ seq_idx,
125
+ initial_states,
126
+ return_final_states,
127
+ final_states_out,
128
+ activation,
129
+ )
130
+
131
+
132
+ def causal_conv1d_ref(
133
+ x,
134
+ weight,
135
+ bias=None,
136
+ initial_states=None,
137
+ return_final_states=False,
138
+ final_states_out=None,
139
+ activation=None,
140
+ ):
141
+ """
142
+ x: (batch, dim, seqlen)
143
+ weight: (dim, width)
144
+ bias: (dim,)
145
+ initial_states: (batch, dim, width - 1)
146
+ final_states_out: (batch, dim, width - 1)
147
+
148
+ out: (batch, dim, seqlen)
149
+ """
150
+ if activation not in [None, "silu", "swish"]:
151
+ raise NotImplementedError("activation must be None, silu, or swish")
152
+ dtype_in = x.dtype
153
+ x = x.to(weight.dtype)
154
+ seqlen = x.shape[-1]
155
+ dim, width = weight.shape
156
+ if initial_states is None:
157
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=width - 1, groups=dim)
158
+ else:
159
+ x = torch.cat([initial_states, x], dim=-1)
160
+ out = F.conv1d(x, weight.unsqueeze(1), bias, padding=0, groups=dim)
161
+ out = out[..., :seqlen]
162
+ if return_final_states:
163
+ final_states = F.pad(x, (width - 1 - x.shape[-1], 0)).to(
164
+ dtype_in
165
+ ) # (batch, dim, width - 1)
166
+ if final_states_out is not None:
167
+ final_states_out.copy_(final_states)
168
+ else:
169
+ final_states_out = final_states
170
+ out = (out if activation is None else F.silu(out)).to(dtype=dtype_in)
171
+ return out if not return_final_states else (out, final_states_out)
172
+
173
+
174
+ def causal_conv1d_update(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None, conv_state_indices=None):
175
+ """
176
+ x: (batch, dim) or (batch, dim, seqlen)
177
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
178
+ weight: (dim, width)
179
+ bias: (dim,)
180
+ cache_seqlens: (batch,), dtype int32.
181
+ If not None, the conv_state is treated as a circular buffer.
182
+ The conv_state will be updated by copying x to the conv_state starting at the index
183
+ @cache_seqlens % state_len.
184
+ conv_state_indices: (batch,), dtype int32
185
+ If not None, the conv_state is a larger tensor along the batch dim,
186
+ and we are selecting the batch coords specified by conv_state_indices.
187
+ Useful for a continuous batching scenario.
188
+
189
+ out: (batch, dim) or (batch, dim, seqlen)
190
+ """
191
+ if activation not in [None, "silu", "swish"]:
192
+ raise NotImplementedError("activation must be None, silu, or swish")
193
+ activation = activation in ["silu", "swish"]
194
+ unsqueeze = x.dim() == 2
195
+ if unsqueeze:
196
+ x = x.unsqueeze(-1)
197
+ out = causal_conv1d_update_function(
198
+ x, conv_state, weight, bias, activation, cache_seqlens, conv_state_indices
199
+ )
200
+ if unsqueeze:
201
+ out = out.squeeze(-1)
202
+ return out
203
+
204
+
205
+ def causal_conv1d_update_ref(x, conv_state, weight, bias=None, activation=None, cache_seqlens=None):
206
+ """
207
+ x: (batch, dim) or (batch, dim, seqlen)
208
+ conv_state: (batch, dim, state_len), where state_len >= width - 1
209
+ weight: (dim, width)
210
+ bias: (dim,)
211
+ cache_seqlens: (batch,), dtype int32.
212
+ If not None, the conv_state is treated as a circular buffer.
213
+ The conv_state will be updated by copying x to the conv_state starting at the index
214
+ @cache_seqlens % state_len before performing the convolution.
215
+
216
+ out: (batch, dim) or (batch, dim, seqlen)
217
+ """
218
+ if activation not in [None, "silu", "swish"]:
219
+ raise NotImplementedError("activation must be None, silu, or swish")
220
+ dtype_in = x.dtype
221
+ unsqueeze = x.dim() == 2
222
+ if unsqueeze:
223
+ x = x.unsqueeze(-1)
224
+ batch, dim, seqlen = x.shape
225
+ width = weight.shape[1]
226
+ state_len = conv_state.shape[-1]
227
+ assert conv_state.shape == (batch, dim, state_len)
228
+ assert weight.shape == (dim, width)
229
+ if cache_seqlens is None:
230
+ x_new = torch.cat([conv_state, x], dim=-1).to(weight.dtype) # (batch, dim, state_len + seqlen)
231
+ conv_state.copy_(x_new[:, :, -state_len:])
232
+ else:
233
+ width_idx = torch.arange(-(width - 1), 0, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
234
+ width_idx = torch.remainder(width_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
235
+ x_new = torch.cat([conv_state.gather(2, width_idx), x], dim=-1).to(weight.dtype)
236
+ copy_idx = torch.arange(seqlen, dtype=torch.long, device=x.device).unsqueeze(0) + cache_seqlens.unsqueeze(1)
237
+ copy_idx = torch.remainder(copy_idx, state_len).unsqueeze(1).expand(-1, dim, -1)
238
+ conv_state.scatter_(2, copy_idx, x)
239
+ out = F.conv1d(x_new, weight.unsqueeze(1), bias, padding=0, groups=dim)[:, :, -seqlen:]
240
+ if unsqueeze:
241
+ out = out.squeeze(-1)
242
+ return (out if activation is None else F.silu(out)).to(dtype=dtype_in)
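A hedged decoding sketch of `causal_conv1d_update`, mirroring the docstring above (not part of this commit; it assumes the same `causal_conv1d` package and a CUDA device):

```python
# Hedged sketch: incremental decoding with a rolling conv_state.
# Assumes the built extension is importable as `causal_conv1d`.
import torch
from causal_conv1d import causal_conv1d_update

batch, dim, width, state_len = 2, 64, 4, 3  # state_len >= width - 1
weight = torch.randn(dim, width, device="cuda", dtype=torch.float32)
conv_state = torch.zeros(batch, dim, state_len, device="cuda", dtype=torch.float16)

for _ in range(5):
    x_t = torch.randn(batch, dim, device="cuda", dtype=torch.float16)  # one new token
    # conv_state is updated in place; y_t is the conv output for this step.
    y_t = causal_conv1d_update(x_t, conv_state, weight, bias=None, activation="silu")
    assert y_t.shape == (batch, dim)
```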
torch-ext/causal_conv1d/causal_conv1d_varlen.py ADDED
@@ -0,0 +1,86 @@
1
+ import torch
2
+ from torch import Tensor
3
+
4
+ import triton
5
+ import triton.language as tl
6
+
7
+
8
+ @triton.jit
9
+ def _causal_conv1d_varlen_states(
10
+ X,
11
+ CU_SEQLENS,
12
+ STATES,
13
+ state_len,
14
+ dim,
15
+ stride_x_seqlen, stride_x_dim,
16
+ stride_states_batch, stride_states_seqlen, stride_states_dim,
17
+ BLOCK_M: tl.constexpr,
18
+ BLOCK_N: tl.constexpr
19
+ ):
20
+ batch_idx = tl.program_id(2)
21
+ STATES += batch_idx * stride_states_batch
22
+ end_idx = tl.load(CU_SEQLENS + batch_idx + 1)
23
+ start_idx = tl.maximum(tl.load(CU_SEQLENS + batch_idx), end_idx - state_len)
24
+ rows = end_idx - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
25
+ cols = tl.program_id(0) * BLOCK_N + tl.arange(0, BLOCK_N)
26
+ x = tl.load(X + rows[:, None] * stride_x_seqlen + cols[None, :] * stride_x_dim,
27
+ mask=(rows[:, None] >= start_idx) & (cols[None, :] < dim),
28
+ other=0)
29
+ rows_states = state_len - (tl.program_id(1) + 1) * BLOCK_M + tl.arange(0, BLOCK_M)
30
+ tl.store(STATES + rows_states[:, None] * stride_states_seqlen + cols[None, :] * stride_states_dim,
31
+ x,
32
+ mask=(rows_states[:, None] >= 0) & (cols[None, :] < dim))
33
+
34
+
35
+ def causal_conv1d_varlen_states(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
36
+ """
37
+ Forward pass only, does not support backward pass.
38
+ Parameters:
39
+ x: (total_tokens, dim)
40
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
41
+ state_len: int. For each sequence, how many elements from x should be copied to the state.
42
+ If some of those elements belong to a different sequence, the value of the states will be zero.
43
+ Return:
44
+ states: (batch, dim, state_len)
45
+ """
46
+ _, dim = x.shape
47
+ batch = cu_seqlens.shape[0] - 1
48
+ cu_seqlens = cu_seqlens.contiguous()
49
+ states = torch.empty(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
50
+ BLOCK_M = min(triton.next_power_of_2(state_len), 16)
51
+ BLOCK_N = min(triton.next_power_of_2(dim), 256)
52
+ grid = (triton.cdiv(dim, BLOCK_N), triton.cdiv(state_len, BLOCK_M), batch)
53
+ with torch.cuda.device(x.device.index):
54
+ _causal_conv1d_varlen_states[grid](
55
+ x,
56
+ cu_seqlens,
57
+ states,
58
+ state_len,
59
+ dim,
60
+ x.stride(0), x.stride(1),
61
+ states.stride(0), states.stride(2), states.stride(1),
62
+ BLOCK_M=BLOCK_M, BLOCK_N=BLOCK_N
63
+ )
64
+ return states
65
+
66
+
67
+ def causal_conv1d_varlen_states_ref(x: Tensor, cu_seqlens: Tensor, state_len: int) -> Tensor:
68
+ """
69
+ Forward pass only, does not support backward pass.
70
+ Parameters:
71
+ x: (total_tokens, dim)
72
+ cu_seqlens: (batch + 1), must already be sorted. The cumulative sum of the sequence lengths, starting from 0.
73
+ state_len: int. For each sequence, how many elements from x should be copied to the state.
74
+ If some of those elements belong to a different sequence, the value of the states will be zero.
75
+ Return:
76
+ states: (batch, dim, state_len)
77
+ """
78
+ _, dim = x.shape
79
+ batch = cu_seqlens.shape[0] - 1
80
+ cu_seqlens = cu_seqlens.contiguous()
81
+ states = torch.zeros(batch, state_len, dim, dtype=x.dtype, device=x.device).transpose(1, 2)
82
+ for i in range(batch):
83
+ end_idx = cu_seqlens[i + 1]
84
+ start_idx = torch.maximum(cu_seqlens[i], end_idx - state_len)
85
+ states[i, :, -(end_idx - start_idx):] = x[start_idx:end_idx].T
86
+ return states
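A hedged sketch of `causal_conv1d_varlen_states`, following the same construction as `test_causal_conv1d_get_states` above (not part of this commit; it assumes the `causal_conv1d` package, Triton, and a CUDA device):

```python
# Hedged sketch: extract per-sequence conv states from a packed varlen batch.
import torch
import torch.nn.functional as F
from causal_conv1d import causal_conv1d_varlen_states

dim, state_len = 64, 3
seqlens = torch.tensor([5, 12, 7], device="cuda")
cu_seqlens = F.pad(seqlens.cumsum(0), (1, 0))  # [0, 5, 17, 24]
x = torch.randn(int(seqlens.sum()), dim, device="cuda", dtype=torch.float16)

states = causal_conv1d_varlen_states(x, cu_seqlens, state_len)
print(states.shape)  # torch.Size([3, 64, 3])
```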
torch-ext/causal_conv1d/cpp_functions.py ADDED
@@ -0,0 +1,96 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import torch
4
+
5
+ from ._ops import ops
6
+
7
+ def causal_conv1d_fwd_function(
8
+ x: torch.Tensor,
9
+ weight: torch.Tensor,
10
+ bias: torch.Tensor | None,
11
+ seq_idx: torch.Tensor | None,
12
+ initial_states: torch.Tensor | None,
13
+ final_states_out: torch.Tensor | None,
14
+ silu_activation: bool,
15
+ ) -> torch.Tensor:
16
+ out = torch.empty_like(x)
17
+ ops.causal_conv1d_fwd(
18
+ x=x,
19
+ weight=weight,
20
+ bias=bias,
21
+ seq_idx=seq_idx,
22
+ initial_states=initial_states,
23
+ out=out,
24
+ final_states_out=final_states_out,
25
+ silu_activation=silu_activation,
26
+ )
27
+ return out
28
+
29
+
30
+ def causal_conv1d_bwd_function(
31
+ x: torch.Tensor,
32
+ weight: torch.Tensor,
33
+ bias: torch.Tensor | None,
34
+ dout: torch.Tensor,
35
+ seq_idx: torch.Tensor | None,
36
+ initial_states: torch.Tensor | None,
37
+ dfinal_states: torch.Tensor | None,
38
+ dx: torch.Tensor | None,
39
+ return_dinitial_states: bool,
40
+ silu_activation: bool,
41
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None, torch.Tensor | None]:
42
+ batch_size, dim = x.size()[:2]
43
+ width = weight.size(-1)
44
+
45
+ if dx is None:
46
+ dx = torch.empty_like(x)
47
+ dweight = torch.zeros_like(weight, dtype=torch.float32)
48
+ dbias = None
49
+ if bias is not None:
50
+ dbias = torch.zeros_like(bias, dtype=torch.float32)
51
+ dinitial_states = None
52
+ if return_dinitial_states:
53
+ dinitial_states = torch.empty(batch_size, width - 1, dim, device=x.device, dtype=x.dtype).transpose(1, 2)
54
+
55
+ ops.causal_conv1d_bwd(
56
+ x=x,
57
+ weight=weight,
58
+ bias=bias,
59
+ dout=dout,
60
+ seq_idx=seq_idx,
61
+ initial_states=initial_states,
62
+ dfinal_states=dfinal_states,
63
+ dx=dx,
64
+ dweight=dweight,
65
+ dbias=dbias,
66
+ dinitial_states=dinitial_states,
67
+ silu_activation=silu_activation,
68
+ )
69
+
70
+ dweight = dweight.type_as(weight)
71
+ if dbias is not None:
72
+ dbias = dbias.type_as(bias)
73
+ return dx, dweight, dbias, dinitial_states
74
+
75
+
76
+ def causal_conv1d_update_function(
77
+ x: torch.Tensor,
78
+ conv_state: torch.Tensor,
79
+ weight: torch.Tensor,
80
+ bias: torch.Tensor | None,
81
+ silu_activation: bool,
82
+ cache_seqlens: torch.Tensor | None,
83
+ conv_state_indices: torch.Tensor | None,
84
+ ) -> torch.Tensor:
85
+ out = torch.empty_like(x)
86
+ ops.causal_conv1d_update(
87
+ x=x,
88
+ conv_state=conv_state,
89
+ weight=weight,
90
+ bias=bias,
91
+ out=out,
92
+ silu_activation=silu_activation,
93
+ cache_seqlens=cache_seqlens,
94
+ conv_state_indices=conv_state_indices,
95
+ )
96
+ return out
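For completeness, a hedged sketch of calling the low-level forward wrapper directly (no autograd), under the assumption that the package imports as `causal_conv1d` and the compiled ops are available:

```python
# Hedged sketch: direct call into the low-level forward wrapper (bypasses autograd).
import torch
from causal_conv1d.cpp_functions import causal_conv1d_fwd_function

batch, dim, seqlen, width = 2, 64, 128, 4
x = torch.randn(batch, dim, seqlen, device="cuda", dtype=torch.float16)
weight = torch.randn(dim, width, device="cuda", dtype=torch.float32)

out = causal_conv1d_fwd_function(
    x, weight, bias=None, seq_idx=None,
    initial_states=None, final_states_out=None, silu_activation=True,
)
print(out.shape)  # torch.Size([2, 64, 128])
```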
torch-ext/pytorch_shim.h ADDED
@@ -0,0 +1,105 @@
1
+ #pragma once
2
+
3
+ #include <torch/library.h>
4
+
5
+ /**
6
+ * Unfortunately, the type signatures of the flash_attn ops are not compatible
7
+ * with the PyTorch library bindings. To get around that we use
8
+ * `make_pytorch_shim` which creates a lambda that exposes the API using
9
+ * PyTorch-compatible types, then converts them to the types
10
+ * expected by the flash_attn ops. This shim allows us to make minimal changes
11
+ * to `flash_api.cpp` making it easier to synchronize with upstream changes.
12
+ *
13
+ * The `pytorch_library_compatible_type` struct is used to map from the
14
+ * flash_attn ops types to a PyTorch library compatible one. The main issues is
15
+ * that the following types are not support by PyTorch libary bindings:
16
+ * - `int`
17
+ * - `float`
18
+ * - `std::optional<T> &`
19
+ * - `std::optional<const at::Tensor> &`
20
+ * So we convert them to (respectively):
21
+ * - `int64_t`
22
+ * - `double`
23
+ * - `const std::optional<T>&`
24
+ * - `const std::optional<at::Tensor>&`
25
+ */
26
+
27
+ template<typename T>
28
+ struct pytorch_library_compatible_type {
29
+ using type = T;
30
+ static T convert_from_type(T arg) { return arg; }
31
+ };
32
+
33
+ template<typename T>
34
+ using pytorch_library_compatible_type_t = \
35
+ typename pytorch_library_compatible_type<T>::type;
36
+
37
+ template<typename T>
38
+ T convert_from_pytorch_compatible_type(pytorch_library_compatible_type_t<T> arg)
39
+ { return pytorch_library_compatible_type<T>::convert_from_type(arg); }
40
+
41
+ // Map `std::optional<T> &` -> `const std::optional<T>&`
42
+ // (NOTE: this is a bit unsafe but none of the ops in flash_attn mutate
43
+ // the optional container)
44
+ template<typename T>
45
+ struct pytorch_library_compatible_type<std::optional<T> &> {
46
+ using type = const std::optional<T>&;
47
+ static std::optional<T>& convert_from_type(const std::optional<T> &arg) {
48
+ return const_cast<std::optional<T>&>(arg);
49
+ }
50
+ };
51
+
52
+ // Map `std::optional<T>` ->
53
+ // `std::optional<pytorch_library_compatible_type_t<T>>`
54
+ // (NOTE: tested for `std::optional<int>` -> `std::optional<int64_t>`)
55
+ template<typename T>
56
+ struct pytorch_library_compatible_type<std::optional<T>> {
57
+ using type = std::optional<pytorch_library_compatible_type_t<T>>;
58
+ static std::optional<pytorch_library_compatible_type_t<T>> convert_from_type(std::optional<T> arg) {
59
+ return arg;
60
+ }
61
+ };
62
+
63
+ // Map `std::optional<const at::Tensor>&` -> `const std::optional<at::Tensor>&`
64
+ template<>
65
+ struct pytorch_library_compatible_type<std::optional<const at::Tensor> &> {
66
+ using type = const std::optional<at::Tensor>&;
67
+ static std::optional<const at::Tensor>& convert_from_type(
68
+ const std::optional<at::Tensor> &arg) {
69
+ return const_cast<std::optional<const at::Tensor>&>(
70
+ reinterpret_cast<const std::optional<const at::Tensor>&>(arg));
71
+ }
72
+ };
73
+
74
+ // Map `int` -> `int64_t`
75
+ template<> struct pytorch_library_compatible_type<int> {
76
+ using type = int64_t;
77
+ static int convert_from_type(int64_t arg) {
78
+ TORCH_CHECK(arg <= std::numeric_limits<int>::max(),
79
+ "int64_t value is too large to be converted to int");
80
+ TORCH_CHECK(arg >= std::numeric_limits<int>::min(),
81
+ "int64_t value is too small to be converted to int");
82
+ return arg;
83
+ }
84
+ };
85
+
86
+ // Map `float` -> `double`
87
+ template<> struct pytorch_library_compatible_type<float> {
88
+ using type = double;
89
+ static float convert_from_type(double arg) {
90
+ TORCH_CHECK(std::abs(arg) <= std::numeric_limits<float>::max(),
91
+ "double value is too large to be converted to float");
92
+ return arg;
93
+ }
94
+ };
95
+
96
+ //
97
+ // Shim Utils
98
+ //
99
+
100
+ template <typename Ret, typename... Args>
101
+ auto make_pytorch_shim(Ret(*fun)(Args... args)){
102
+ return [fun](pytorch_library_compatible_type_t<Args>... args) {
103
+ return fun(convert_from_pytorch_compatible_type<Args>(args)...);
104
+ };
105
+ }
torch-ext/torch_binding.cpp ADDED
@@ -0,0 +1,32 @@
1
+ #include <torch/library.h>
2
+
3
+ #include "registration.h"
4
+
5
+ #include "pytorch_shim.h"
6
+ #include "torch_binding.h"
7
+
8
+ TORCH_LIBRARY_EXPAND(TORCH_EXTENSION_NAME, ops) {
9
+ ops.def(
10
+ "causal_conv1d_fwd("
11
+ " Tensor x, Tensor weight, Tensor? bias, Tensor? seq_idx,"
12
+ " Tensor? initial_states, Tensor! out, Tensor!? final_states_out,"
13
+ " bool silu_activation) -> ()");
14
+ ops.impl("causal_conv1d_fwd", torch::kCUDA, make_pytorch_shim(&causal_conv1d_fwd));
15
+
16
+ ops.def(
17
+ "causal_conv1d_bwd("
18
+ " Tensor x, Tensor weight, Tensor? bias, Tensor! dout,"
19
+ " Tensor? seq_idx, Tensor? initial_states, Tensor? dfinal_states,"
20
+ " Tensor! dx, Tensor! dweight, Tensor!? dbias,"
21
+ " Tensor!? dinitial_states, bool silu_activation) -> ()");
22
+ ops.impl("causal_conv1d_bwd", torch::kCUDA, make_pytorch_shim(&causal_conv1d_bwd));
23
+
24
+ ops.def(
25
+ "causal_conv1d_update("
26
+ " Tensor x, Tensor conv_state, Tensor weight, Tensor? bias,"
27
+ " Tensor! out, bool silu_activation, Tensor? cache_seqlens,"
28
+ " Tensor? conv_state_indices) -> ()");
29
+ ops.impl("causal_conv1d_update", torch::kCUDA, make_pytorch_shim(&causal_conv1d_update));
30
+ }
31
+
32
+ REGISTER_EXTENSION(TORCH_EXTENSION_NAME)
torch-ext/torch_binding.h ADDED
@@ -0,0 +1,39 @@
1
+ #pragma once
2
+
3
+ #include <torch/torch.h>
4
+
5
+ void
6
+ causal_conv1d_fwd(const at::Tensor &x,
7
+ const at::Tensor &weight,
8
+ const c10::optional<at::Tensor> &bias_,
9
+ const c10::optional<at::Tensor> &seq_idx_,
10
+ const c10::optional<at::Tensor> &initial_states_,
11
+ at::Tensor &out,
12
+ c10::optional<at::Tensor> &final_states_out_,
13
+ bool silu_activation);
14
+
15
+ void
16
+ causal_conv1d_bwd(const at::Tensor &x,
17
+ const at::Tensor &weight,
18
+ const c10::optional<at::Tensor> &bias_,
19
+ at::Tensor &dout,
20
+ const c10::optional<at::Tensor> &seq_idx_,
21
+ const c10::optional<at::Tensor> &initial_states_,
22
+ const c10::optional<at::Tensor> &dfinal_states_,
23
+ at::Tensor &dx,
24
+ at::Tensor &dweight,
25
+ c10::optional<at::Tensor> &dbias_,
26
+ c10::optional<at::Tensor> &dinitial_states_,
27
+ bool silu_activation);
28
+
29
+ void
30
+ causal_conv1d_update(const at::Tensor &x,
31
+ const at::Tensor &conv_state,
32
+ const at::Tensor &weight,
33
+ const c10::optional<at::Tensor> &bias_,
34
+ at::Tensor &out,
35
+ bool silu_activation,
36
+ const c10::optional<at::Tensor> &cache_seqlens_,
37
+ const c10::optional<at::Tensor> &conv_state_indices_
38
+ );
39
+