danieldk (HF Staff) committed
Commit bbbdefe · 1 Parent(s): 9b61b27

Sync with upstream, add tests

flake.lock CHANGED
@@ -73,11 +73,11 @@
         "nixpkgs": "nixpkgs"
       },
       "locked": {
-        "lastModified": 1750234878,
-        "narHash": "sha256-q9DRC9zdpzUf88qqg1qbhP1qgJbE2cMtn8oUmosuyT8=",
+        "lastModified": 1754038838,
+        "narHash": "sha256-oHigCT4z0ayyLyEuxdZooSXRAZP8lfOkZHzY1lx1U50=",
         "owner": "huggingface",
         "repo": "hf-nix",
-        "rev": "c7132f90763d756da3e77da62e01be0a4546dc57",
+        "rev": "336f781fa284e193baa3d4c3ce3f95fb34e9ffad",
         "type": "github"
       },
       "original": {
@@ -98,11 +98,11 @@
         ]
       },
       "locked": {
-        "lastModified": 1750409351,
-        "narHash": "sha256-xkzrwee77LrBDtwNNihBkYbY7yUwdOv0/4+J3B5xCZE=",
+        "lastModified": 1756320464,
+        "narHash": "sha256-x9LI4h87/Z9UgTQjgeG0fRcdeXl91xIqBlTauGKZM70=",
         "owner": "huggingface",
         "repo": "kernel-builder",
-        "rev": "9e61fba877153bffa6eaff023243fd81220c0eea",
+        "rev": "b4accba4496b28faef19a0487fbcf9686b14e2ef",
         "type": "github"
       },
       "original": {
@@ -113,17 +113,17 @@
     },
     "nixpkgs": {
       "locked": {
-        "lastModified": 1747820358,
-        "narHash": "sha256-fTqsZsUX6M3yeEvgyQvXcbGmT2CaRVyVwsi8eK29Oj4=",
-        "owner": "danieldk",
+        "lastModified": 1752785354,
+        "narHash": "sha256-Y33ryUz7MPqKrZwlbQcsYCUz2jAJCacRf8jbs0tYUlA=",
+        "owner": "nixos",
         "repo": "nixpkgs",
-        "rev": "d3c1681180717528068082103bf323147de6ab0b",
+        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
         "type": "github"
       },
       "original": {
-        "owner": "danieldk",
-        "ref": "cudatoolkit-12.9-kernel-builder",
+        "owner": "nixos",
         "repo": "nixpkgs",
+        "rev": "d38025438a6ee456758dc03188ca6873a415463b",
         "type": "github"
       }
     },
flake.nix CHANGED
@@ -13,5 +13,8 @@
     kernel-builder.lib.genFlakeOutputs {
       path = ./.;
       rev = self.shortRev or self.dirtyShortRev or self.lastModifiedDate;
+      # Import-time autotune.
+      doGetKernelCheck = false;
+      pythonCheckInputs = pkgs: with pkgs; [ einops ];
     };
 }
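The two options added to genFlakeOutputs both relate to exercising the kernel: the Triton layer-norm kernels autotune when they are first used (hence the "Import-time autotune." note and the disabled get-kernel check), and the new tests depend on einops. As a rough, hedged sketch of what calling the built extension looks like (the shapes, dtypes, and device below are illustrative, not taken from the repo):

import torch
from triton_layer_norm import layer_norm_fn

# The first launch of the @triton.autotune'd kernel benchmarks the warp-count
# configs, so this needs a CUDA device and takes longer than subsequent calls.
x = torch.randn(8, 512, 2048, device="cuda", dtype=torch.float16)
weight = torch.randn(2048, device="cuda", dtype=torch.float16)
bias = torch.randn(2048, device="cuda", dtype=torch.float16)

out = layer_norm_fn(x, weight, bias, eps=1e-6)
print(out.shape, out.dtype)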
tests/test_layer_norm.py ADDED
@@ -0,0 +1,373 @@
1
+ # Copyright (c) 2024, Tri Dao.
2
+
3
+ import pytest
4
+ import torch
5
+ import torch.nn.functional as F
6
+ from einops import rearrange, repeat
7
+
8
+ from triton_layer_norm import (
9
+ layer_norm_fn,
10
+ layer_norm_linear_fn,
11
+ )
12
+ from triton_layer_norm.layer_norm import layer_norm_ref, rms_norm_ref
13
+
14
+
15
+ is_sm8x = torch.cuda.get_device_capability("cuda")[0] >= 8
16
+
17
+
18
+ # @pytest.mark.parametrize("zero_centered_weight", [False, True])
19
+ @pytest.mark.parametrize("zero_centered_weight", [False])
20
+ @pytest.mark.parametrize("has_weight1", [False, True])
21
+ # @pytest.mark.parametrize("has_weight1", [False])
22
+ @pytest.mark.parametrize("has_x1", [False, True])
23
+ # @pytest.mark.parametrize("has_x1", [False])
24
+ @pytest.mark.parametrize("has_rowscale", [False, True])
25
+ # @pytest.mark.parametrize("has_rowscale", [False])
26
+ @pytest.mark.parametrize("dropout_p", [0.0, 0.27])
27
+ # @pytest.mark.parametrize("dropout_p", [0.0])
28
+ @pytest.mark.parametrize("prenorm", [True, False])
29
+ # @pytest.mark.parametrize("prenorm", [True])
30
+ @pytest.mark.parametrize("is_rms_norm", [False, True])
31
+ # @pytest.mark.parametrize("is_rms_norm", [True])
32
+ @pytest.mark.parametrize("has_residual", [True, False])
33
+ # @pytest.mark.parametrize("has_residual", [True])
34
+ @pytest.mark.parametrize(
35
+ "weight_dtype", [torch.float32, torch.float16] + ([torch.bfloat16] if is_sm8x else [])
36
+ )
37
+ # @pytest.mark.parametrize("weight_dtype", [torch.float32])
38
+ @pytest.mark.parametrize(
39
+ "input_dtype,residual_dtype",
40
+ [(torch.float16, torch.float16), (torch.float16, torch.float32), (torch.float32, torch.float32)]
41
+ + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
42
+ )
43
+ # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.float16, torch.float16)])
44
+ @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000, 4096])
45
+ # @pytest.mark.parametrize("hidden_size", [1024])
46
+ def test_layer_norm(
47
+ hidden_size,
48
+ input_dtype,
49
+ residual_dtype,
50
+ weight_dtype,
51
+ has_residual,
52
+ is_rms_norm,
53
+ prenorm,
54
+ dropout_p,
55
+ has_rowscale,
56
+ has_x1,
57
+ has_weight1,
58
+ zero_centered_weight,
59
+ ):
60
+ if has_rowscale and has_x1:
61
+ pytest.skip("Not supported")
62
+ device = "cuda"
63
+ if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
64
+ atol = 5e-2
65
+ elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
66
+ atol = 1e-2
67
+ else:
68
+ atol = 1e-4
69
+ # set seed
70
+ torch.random.manual_seed(0)
71
+ batch_size = 8
72
+ seqlen = 512
73
+ layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
74
+ allclose = (
75
+ # Sometimes x0_pt.grad is NaN
76
+ lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
77
+ <= 2 * (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() + atol
78
+ or (
79
+ # Sometimes x_pt and x_ref are the same (e.g. bfloat16) so we want to perturb it a bit
80
+ # by multiplying and dividing by 0.3
81
+ (x_pt[~x_pt.isnan()] - x_ref[~x_pt.isnan()]).abs().max() == 0.0
82
+ and (x - x_ref).abs().max()
83
+ <= 2 * (x_pt[~x_pt.isnan()] * 0.3 / 0.3 - x_ref[~x_pt.isnan()]).abs().max() + atol
84
+ )
85
+ )
86
+ x0 = torch.randn(
87
+ batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
88
+ )
89
+ x0_pt = x0.detach().clone().requires_grad_()
90
+ x0_ref = x0.detach().clone().requires_grad_()
91
+ if has_residual:
92
+ res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
93
+ res_pt = res.detach().clone().requires_grad_()
94
+ res_ref = res.detach().clone().requires_grad_()
95
+ else:
96
+ res, res_pt, res_ref = None, None, None
97
+ weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
98
+ if not is_rms_norm:
99
+ bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
100
+ else:
101
+ bias = None
102
+ weight_pt = weight.detach().clone().requires_grad_()
103
+ weight_ref = weight.detach().clone().requires_grad_()
104
+ bias_pt = bias.detach().clone().requires_grad_() if bias is not None else None
105
+ bias_ref = bias.detach().clone().requires_grad_() if bias is not None else None
106
+ if has_x1:
107
+ x1 = torch.randn_like(x0, dtype=input_dtype, requires_grad=True)
108
+ x1_pt = x1.detach().clone().requires_grad_()
109
+ x1_ref = x1.detach().clone().requires_grad_()
110
+ else:
111
+ x1, x1_pt, x1_ref = None, None, None
112
+ if has_weight1:
113
+ weight1 = torch.randn(
114
+ hidden_size, device=device, dtype=weight_dtype, requires_grad=True
115
+ )
116
+ weight1_pt = weight1.detach().clone().requires_grad_()
117
+ weight1_ref = weight1.detach().clone().requires_grad_()
118
+ if not is_rms_norm:
119
+ bias1 = torch.randn(
120
+ hidden_size, device=device, dtype=weight_dtype, requires_grad=True
121
+ )
122
+ else:
123
+ bias1 = None
124
+ bias1_pt = bias1.detach().clone().requires_grad_() if bias1 is not None else None
125
+ bias1_ref = bias1.detach().clone().requires_grad_() if bias1 is not None else None
126
+ else:
127
+ weight1, weight1_pt, weight1_ref = None, None, None
128
+ bias1, bias1_pt, bias1_ref = None, None, None
129
+
130
+ rowscale = (
131
+ torch.randn(batch_size, seqlen, dtype=input_dtype, device=device)
132
+ if has_rowscale
133
+ else None
134
+ )
135
+
136
+ residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
137
+ out, *rest = layer_norm_fn(
138
+ x0,
139
+ weight,
140
+ bias,
141
+ residual=res,
142
+ x1=x1,
143
+ weight1=weight1,
144
+ bias1=bias1,
145
+ eps=1e-6,
146
+ dropout_p=dropout_p,
147
+ rowscale=rowscale,
148
+ prenorm=prenorm,
149
+ residual_in_fp32=residual_in_fp32,
150
+ zero_centered_weight=zero_centered_weight,
151
+ is_rms_norm=is_rms_norm,
152
+ return_dropout_mask=True,
153
+ )
154
+ dropout_mask = rest[-2] if dropout_p > 0.0 else None
155
+ dropout_mask1 = rest[-1] if dropout_p > 0.0 and x1 is not None else None
156
+ out_pt = layer_norm_ref_fn(
157
+ x0_pt,
158
+ weight_pt,
159
+ bias_pt,
160
+ residual=res_pt,
161
+ x1=x1_pt,
162
+ weight1=weight1_pt,
163
+ bias1=bias1_pt,
164
+ eps=1e-6,
165
+ dropout_p=dropout_p,
166
+ rowscale=rowscale,
167
+ prenorm=prenorm,
168
+ zero_centered_weight=zero_centered_weight,
169
+ dropout_mask=dropout_mask,
170
+ dropout_mask1=dropout_mask1,
171
+ )
172
+ out_ref = layer_norm_ref_fn(
173
+ x0_ref,
174
+ weight_ref,
175
+ bias_ref,
176
+ residual=res_ref,
177
+ x1=x1_ref,
178
+ weight1=weight1_ref,
179
+ bias1=bias1_ref,
180
+ eps=1e-6,
181
+ dropout_p=dropout_p,
182
+ rowscale=rowscale,
183
+ prenorm=prenorm,
184
+ zero_centered_weight=zero_centered_weight,
185
+ dropout_mask=dropout_mask,
186
+ dropout_mask1=dropout_mask1,
187
+ upcast=True,
188
+ )
189
+ if not has_weight1:
190
+ if prenorm:
191
+ residual = rest[0]
192
+ out_pt, residual_pt = out_pt
193
+ out_ref, residual_ref = out_ref
194
+ out1, out1_pt, out1_ref = None, None, None
195
+ else:
196
+ out1 = rest.pop(0)
197
+ if prenorm:
198
+ residual = rest[0]
199
+ out_pt, out1_pt, residual_pt = out_pt
200
+ out_ref, out1_ref, residual_ref = out_ref
201
+ else:
202
+ out_pt, out1_pt = out_pt
203
+ out_ref, out1_ref = out_ref
204
+ assert out.dtype == input_dtype
205
+ if prenorm:
206
+ assert residual.dtype == residual_dtype
207
+ assert allclose(residual, residual_pt, residual_ref)
208
+ assert allclose(out, out_pt, out_ref)
209
+ if out1 is not None:
210
+ assert out1.dtype == input_dtype
211
+ assert allclose(out1, out1_pt, out1_ref)
212
+ if dropout_mask is not None:
213
+ dropout_fraction = 1.0 - dropout_mask.float().mean()
214
+ assert abs(dropout_fraction - dropout_p) < 0.01
215
+ if dropout_mask1 is not None:
216
+ dropout_fraction = 1.0 - dropout_mask1.float().mean()
217
+ assert abs(dropout_fraction - dropout_p) < 0.01
218
+ assert not torch.equal(dropout_mask, dropout_mask1)
219
+
220
+ g = torch.randn_like(out) / batch_size
221
+ if has_weight1:
222
+ out = out * F.gelu(out1)
223
+ out_pt = out_pt * F.gelu(out1_pt)
224
+ out_ref = out_ref * F.gelu(out1_ref)
225
+ if not prenorm:
226
+ out.backward(g)
227
+ out_pt.backward(g)
228
+ out_ref.backward(g)
229
+ else:
230
+ (out * F.sigmoid(residual)).backward(g)
231
+ (out_pt * F.sigmoid(residual_pt)).backward(g)
232
+ (out_ref * F.sigmoid(residual_ref.to(dtype=residual_dtype))).backward(g)
233
+ assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
234
+ if has_residual:
235
+ assert allclose(res.grad, res_pt.grad, res_ref.grad)
236
+ if has_x1:
237
+ assert allclose(x1.grad, x1_pt.grad, x1_ref.grad)
238
+ assert allclose(weight.grad, weight_pt.grad, weight_ref.grad)
239
+ if bias is not None:
240
+ assert allclose(bias.grad, bias_pt.grad, bias_ref.grad)
241
+ if has_weight1:
242
+ assert allclose(weight1.grad, weight1_pt.grad, weight1_ref.grad)
243
+ if bias1 is not None:
244
+ assert allclose(bias1.grad, bias1_pt.grad, bias1_ref.grad)
245
+
246
+
247
+ @pytest.mark.parametrize("prenorm", [True, False])
248
+ # @pytest.mark.parametrize("prenorm", [True])
249
+ @pytest.mark.parametrize("is_rms_norm", [False, True])
250
+ # @pytest.mark.parametrize("is_rms_norm", [True])
251
+ @pytest.mark.parametrize("has_residual", [True, False])
252
+ # @pytest.mark.parametrize("has_residual", [False])
253
+ @pytest.mark.parametrize("weight_dtype", [torch.float32])
254
+ @pytest.mark.parametrize(
255
+ "input_dtype,residual_dtype",
256
+ [(torch.float16, torch.float16), (torch.float16, torch.float32)]
257
+ + ([(torch.bfloat16, torch.bfloat16), (torch.bfloat16, torch.float32)] if is_sm8x else []),
258
+ )
259
+ # @pytest.mark.parametrize("input_dtype,residual_dtype", [(torch.bfloat16, torch.float32)])
260
+ @pytest.mark.parametrize("hidden_size", [192, 2048, 2560, 3000])
261
+ # @pytest.mark.parametrize("hidden_size", [256])
262
+ def test_layer_norm_linear(
263
+ hidden_size, input_dtype, residual_dtype, weight_dtype, has_residual, is_rms_norm, prenorm
264
+ ):
265
+ device = "cuda"
266
+ if any(x == torch.bfloat16 for x in [input_dtype, residual_dtype, weight_dtype]):
267
+ atol = 5e-2
268
+ elif any(x == torch.float16 for x in [input_dtype, residual_dtype, weight_dtype]):
269
+ atol = 1e-2
270
+ else:
271
+ atol = 1e-4
272
+ # set seed
273
+ torch.random.manual_seed(0)
274
+ batch_size = 4
275
+ seqlen = 512
276
+ # batch_size = 1
277
+ # seqlen = 1
278
+ layer_norm_ref_fn = layer_norm_ref if not is_rms_norm else rms_norm_ref
279
+ allclose = (
280
+ lambda x, x_pt, x_ref, atol=atol: (x - x_ref).abs().max()
281
+ <= 2 * (x_pt - x_ref).abs().max() + atol
282
+ )
283
+ x0 = torch.randn(
284
+ batch_size, seqlen, hidden_size, device=device, dtype=input_dtype, requires_grad=True
285
+ )
286
+ x0_pt = x0.detach().clone().requires_grad_()
287
+ x0_ref = x0.detach().clone().requires_grad_()
288
+ if has_residual:
289
+ res = torch.randn_like(x0, dtype=residual_dtype, requires_grad=True)
290
+ res_pt = res.detach().clone().requires_grad_()
291
+ res_ref = res.detach().clone().requires_grad_()
292
+ else:
293
+ res, res_pt, res_ref = None, None, None
294
+ norm_weight = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
295
+ if not is_rms_norm:
296
+ norm_bias = torch.randn(hidden_size, device=device, dtype=weight_dtype, requires_grad=True)
297
+ else:
298
+ norm_bias = None
299
+ norm_weight_pt = norm_weight.detach().clone().requires_grad_()
300
+ norm_weight_ref = norm_weight.detach().clone().requires_grad_()
301
+ norm_bias_pt = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
302
+ norm_bias_ref = norm_bias.detach().clone().requires_grad_() if norm_bias is not None else None
303
+ linear_weight = torch.empty(
304
+ 2 * hidden_size, hidden_size, device=device, dtype=weight_dtype, requires_grad=True
305
+ )
306
+ torch.nn.init.xavier_uniform_(linear_weight)
307
+ if not is_rms_norm:
308
+ linear_bias = torch.randn(
309
+ 2 * hidden_size, device=device, dtype=weight_dtype, requires_grad=True
310
+ )
311
+ else:
312
+ linear_bias = None
313
+ linear_weight_pt = linear_weight.detach().clone().requires_grad_()
314
+ linear_weight_ref = linear_weight.detach().clone().requires_grad_()
315
+ linear_bias_pt = (
316
+ linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
317
+ )
318
+ linear_bias_ref = (
319
+ linear_bias.detach().clone().requires_grad_() if linear_bias is not None else None
320
+ )
321
+
322
+ residual_in_fp32 = (not has_residual) and residual_dtype == torch.float32
323
+ with torch.autocast(device_type="cuda", dtype=input_dtype):
324
+ out, *rest = layer_norm_linear_fn(
325
+ x0,
326
+ norm_weight,
327
+ norm_bias,
328
+ linear_weight,
329
+ linear_bias,
330
+ residual=res,
331
+ eps=1e-6,
332
+ prenorm=prenorm,
333
+ residual_in_fp32=residual_in_fp32,
334
+ is_rms_norm=is_rms_norm,
335
+ )
336
+ out_pt, *rest_pt = layer_norm_ref_fn(
337
+ x0_pt, norm_weight_pt, norm_bias_pt, residual=res_pt, eps=1e-6, prenorm=prenorm
338
+ )
339
+ with torch.autocast(device_type="cuda", dtype=input_dtype):
340
+ out_pt = F.linear(out_pt, linear_weight_pt, linear_bias_pt)
341
+ out_ref, *rest_ref = layer_norm_ref_fn(
342
+ x0_ref,
343
+ norm_weight_ref,
344
+ norm_bias_ref,
345
+ residual=res_ref,
346
+ eps=1e-6,
347
+ prenorm=prenorm,
348
+ upcast=True,
349
+ )
350
+ out_ref = F.linear(out_ref.to(linear_weight_ref.dtype), linear_weight_ref, linear_bias_ref)
351
+ if prenorm:
352
+ residual = rest[0]
353
+ residual_pt = rest_pt[0]
354
+ residual_ref = rest_ref[0]
355
+ assert out.dtype == input_dtype
356
+ if prenorm:
357
+ assert residual.dtype == residual_dtype
358
+ assert allclose(residual, residual_pt, residual_ref)
359
+ assert allclose(out, out_pt, out_ref)
360
+
361
+ g = torch.randn_like(out) / batch_size
362
+ out.backward(g)
363
+ out_pt.backward(g)
364
+ out_ref.backward(g)
365
+ assert allclose(x0.grad, x0_pt.grad, x0_ref.grad)
366
+ if has_residual:
367
+ assert allclose(res.grad, res_pt.grad, res_ref.grad)
368
+ assert allclose(norm_weight.grad, norm_weight_pt.grad, norm_weight_ref.grad)
369
+ if norm_bias is not None:
370
+ assert allclose(norm_bias.grad, norm_bias_pt.grad, norm_bias_ref.grad)
371
+ assert allclose(linear_weight.grad, linear_weight_pt.grad, linear_weight_ref.grad)
372
+ if linear_bias is not None:
373
+ assert allclose(linear_bias.grad, linear_bias_pt.grad, linear_bias_ref.grad)
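The test module above needs a CUDA GPU (the bfloat16 cases are additionally gated on is_sm8x) plus einops, which is exactly what the pythonCheckInputs addition in flake.nix provides. A hedged sketch of a local invocation; the path and the -k filter are illustrative:

import pytest

# Assumes pytest, einops, a CUDA device, and an importable triton_layer_norm build.
pytest.main(["tests/test_layer_norm.py", "-k", "test_layer_norm_linear", "-x"])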
torch-ext/triton_layer_norm/__init__.py CHANGED
@@ -25,6 +25,7 @@ def layer_norm(
     rowscale=None,
     prenorm: bool = False,
     residual_in_fp32: bool = False,
+    zero_centered_weight: bool = False,
     is_rms_norm: bool = False,
     return_dropout_mask: bool = False,
     out: Optional[torch.Tensor] = None,
@@ -61,6 +62,8 @@ def layer_norm(
             If True, returns both the normalized output and the unnormalized input+residual.
         residual_in_fp32 (`bool`, *optional*, defaults to False):
             If True, performs the residual connection in FP32 precision.
+        zero_centered_weight (`bool`, *optional*, defaults to False):
+            When set to true, 1.0 is added to the weight before applying it.
         is_rms_norm (`bool`, *optional*, defaults to False):
             If True, uses RMS normalization instead of layer normalization.
         return_dropout_mask (`bool`, *optional*, defaults to False):
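To make the new flag concrete: with zero_centered_weight the scale parameter is stored centered around 0 and 1.0 is added before it is applied, so a zero-initialized weight behaves like an all-ones weight. A minimal reference sketch in plain PyTorch (not the Triton path; the helper name and shapes are illustrative):

import torch
import torch.nn.functional as F

def layer_norm_zero_centered_ref(x, weight, bias, eps=1e-6):
    # zero_centered_weight=True is equivalent to shifting the weight by 1.0
    # before the usual affine transform.
    return F.layer_norm(
        x.float(), x.shape[-1:], weight=weight.float() + 1.0, bias=bias.float(), eps=eps
    ).to(x.dtype)

x = torch.randn(4, 192)
w = torch.zeros(192)  # zero-centered: acts like an all-ones weight
b = torch.zeros(192)
torch.testing.assert_close(
    layer_norm_zero_centered_ref(x, w, b),
    F.layer_norm(x, x.shape[-1:], eps=1e-6),
)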
torch-ext/triton_layer_norm/layer_norm.py CHANGED
@@ -7,14 +7,40 @@
7
  # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
 
9
  import math
 
10
 
11
  import torch
12
  import torch.nn.functional as F
13
- from torch.amp import custom_fwd, custom_bwd
14
 
15
  import triton
16
  import triton.language as tl
17
 
 
 
 
 
18
 
19
  def layer_norm_ref(
20
  x,
@@ -28,6 +54,7 @@ def layer_norm_ref(
28
  dropout_p=0.0,
29
  rowscale=None,
30
  prenorm=False,
 
31
  dropout_mask=None,
32
  dropout_mask1=None,
33
  upcast=False,
@@ -41,6 +68,10 @@ def layer_norm_ref(
41
  x1 = x1.float() if x1 is not None else None
42
  weight1 = weight1.float() if weight1 is not None else None
43
  bias1 = bias1.float() if bias1 is not None else None
 
 
 
 
44
  if x1 is not None:
45
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
46
  if rowscale is not None:
@@ -59,9 +90,9 @@ def layer_norm_ref(
59
  x = x + x1
60
  if residual is not None:
61
  x = (x + residual).to(x.dtype)
62
- out = F.layer_norm(
63
- x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps
64
- ).to(dtype)
65
  if weight1 is None:
66
  return out if not prenorm else (out, x)
67
  else:
@@ -83,6 +114,7 @@ def rms_norm_ref(
83
  dropout_p=0.0,
84
  rowscale=None,
85
  prenorm=False,
 
86
  dropout_mask=None,
87
  dropout_mask1=None,
88
  upcast=False,
@@ -96,6 +128,10 @@ def rms_norm_ref(
96
  x1 = x1.float() if x1 is not None else None
97
  weight1 = weight1.float() if weight1 is not None else None
98
  bias1 = bias1.float() if bias1 is not None else None
 
 
 
 
99
  if x1 is not None:
100
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
101
  if rowscale is not None:
@@ -115,34 +151,26 @@ def rms_norm_ref(
115
  if residual is not None:
116
  x = (x + residual).to(x.dtype)
117
  rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
118
- out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(
119
- dtype
120
- )
121
  if weight1 is None:
122
  return out if not prenorm else (out, x)
123
  else:
124
- out1 = (
125
- (x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)
126
- ).to(dtype)
127
  return (out, out1) if not prenorm else (out, out1, x)
128
 
129
 
130
  @triton.autotune(
131
- configs=[
132
- triton.Config({}, num_warps=1),
133
- triton.Config({}, num_warps=2),
134
- triton.Config({}, num_warps=4),
135
- triton.Config({}, num_warps=8),
136
- triton.Config({}, num_warps=16),
137
- triton.Config({}, num_warps=32),
138
- ],
139
- key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS"],
140
  )
 
141
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
142
  # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
143
- @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
144
- @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
145
- @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
146
  @triton.jit
147
  def _layer_norm_fwd_1pass_kernel(
148
  X, # pointer to the input
@@ -158,6 +186,7 @@ def _layer_norm_fwd_1pass_kernel(
158
  ROWSCALE,
159
  SEEDS, # Dropout seeds for each row
160
  DROPOUT_MASK,
 
161
  Mean, # pointer to the mean
162
  Rstd, # pointer to the 1/std
163
  stride_x_row, # how much to increase the pointer when moving by 1 row
@@ -170,6 +199,7 @@ def _layer_norm_fwd_1pass_kernel(
170
  N, # number of columns in X
171
  eps, # epsilon to avoid division by zero
172
  dropout_p, # Dropout probability
 
173
  IS_RMS_NORM: tl.constexpr,
174
  BLOCK_N: tl.constexpr,
175
  HAS_RESIDUAL: tl.constexpr,
@@ -203,9 +233,7 @@ def _layer_norm_fwd_1pass_kernel(
203
  if HAS_DROPOUT:
204
  # Compute dropout mask
205
  # 7 rounds is good enough, and reduces register pressure
206
- keep_mask = (
207
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
208
- )
209
  x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
210
  if STORE_DROPOUT_MASK:
211
  tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
@@ -218,12 +246,11 @@ def _layer_norm_fwd_1pass_kernel(
218
  # Compute dropout mask
219
  # 7 rounds is good enough, and reduces register pressure
220
  keep_mask = (
221
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
222
- > dropout_p
223
  )
224
  x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
225
  if STORE_DROPOUT_MASK:
226
- tl.store(DROPOUT_MASK + (M + row) * N + cols, keep_mask, mask=cols < N)
227
  x += x1
228
  if HAS_RESIDUAL:
229
  residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
@@ -243,6 +270,8 @@ def _layer_norm_fwd_1pass_kernel(
243
  # Normalize and apply linear transformation
244
  mask = cols < N
245
  w = tl.load(W + cols, mask=mask).to(tl.float32)
 
 
246
  if HAS_BIAS:
247
  b = tl.load(B + cols, mask=mask).to(tl.float32)
248
  x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
@@ -251,6 +280,8 @@ def _layer_norm_fwd_1pass_kernel(
251
  tl.store(Y + cols, y, mask=mask)
252
  if HAS_W1:
253
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
 
 
254
  if HAS_B1:
255
  b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
256
  y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
@@ -258,25 +289,87 @@ def _layer_norm_fwd_1pass_kernel(
258
 
259
 
260
  def _layer_norm_fwd(
261
- x,
262
- weight,
263
- bias,
264
- eps,
265
- residual=None,
266
- x1=None,
267
- weight1=None,
268
- bias1=None,
269
- dropout_p=0.0,
270
- rowscale=None,
271
- out_dtype=None,
272
- residual_dtype=None,
273
- is_rms_norm=False,
274
- return_dropout_mask=False,
275
- out=None,
276
- residual_out=None,
277
- ):
 
 
 
 
 
 
278
  if residual is not None:
279
  residual_dtype = residual.dtype
 
 
280
  M, N = x.shape
281
  assert x.stride(-1) == 1
282
  if residual is not None:
@@ -300,41 +393,17 @@ def _layer_norm_fwd(
300
  if rowscale is not None:
301
  assert rowscale.is_contiguous()
302
  assert rowscale.shape == (M,)
303
- # allocate output
304
- if out is None:
305
- out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
306
- else:
307
- assert out.shape == x.shape
308
  assert out.stride(-1) == 1
 
 
 
309
  if weight1 is not None:
310
  y1 = torch.empty_like(out)
311
  assert y1.stride(-1) == 1
312
  else:
313
  y1 = None
314
- if (
315
- residual is not None
316
- or (residual_dtype is not None and residual_dtype != x.dtype)
317
- or dropout_p > 0.0
318
- or rowscale is not None
319
- or x1 is not None
320
- ):
321
- if residual_out is None:
322
- residual_out = torch.empty(
323
- M,
324
- N,
325
- device=x.device,
326
- dtype=residual_dtype if residual_dtype is not None else x.dtype,
327
- )
328
- else:
329
- assert residual_out.shape == x.shape
330
- assert residual_out.stride(-1) == 1
331
- else:
332
- residual_out = None
333
- mean = (
334
- torch.empty((M,), dtype=torch.float32, device=x.device)
335
- if not is_rms_norm
336
- else None
337
- )
338
  rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
339
  if dropout_p > 0.0:
340
  seeds = torch.randint(
@@ -343,18 +412,20 @@ def _layer_norm_fwd(
343
  else:
344
  seeds = None
345
  if return_dropout_mask and dropout_p > 0.0:
346
- dropout_mask = torch.empty(
347
- M if x1 is None else 2 * M, N, device=x.device, dtype=torch.bool
348
- )
 
 
349
  else:
350
- dropout_mask = None
351
  # Less than 64KB per feature: enqueue fused kernel
352
  MAX_FUSED_SIZE = 65536 // x.element_size()
353
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
354
  if N > BLOCK_N:
355
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
356
  with torch.cuda.device(x.device.index):
357
- _layer_norm_fwd_1pass_kernel[(M,)](
358
  x,
359
  out,
360
  weight,
@@ -368,6 +439,7 @@ def _layer_norm_fwd(
368
  rowscale,
369
  seeds,
370
  dropout_mask,
 
371
  mean,
372
  rstd,
373
  x.stride(0),
@@ -380,6 +452,8 @@ def _layer_norm_fwd(
380
  N,
381
  eps,
382
  dropout_p,
 
 
383
  is_rms_norm,
384
  BLOCK_N,
385
  residual is not None,
@@ -388,50 +462,26 @@ def _layer_norm_fwd(
388
  dropout_p > 0.0,
389
  dropout_mask is not None,
390
  rowscale is not None,
 
 
 
391
  )
392
- # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
393
- if dropout_mask is not None and x1 is not None:
394
- dropout_mask, dropout_mask1 = dropout_mask.tensor_split(2, dim=0)
395
- else:
396
- dropout_mask1 = None
397
- return (
398
- out,
399
- y1,
400
- mean,
401
- rstd,
402
- residual_out if residual_out is not None else x,
403
- seeds,
404
- dropout_mask,
405
- dropout_mask1,
406
- )
407
 
408
 
409
  @triton.autotune(
410
- configs=[
411
- triton.Config({}, num_warps=1),
412
- triton.Config({}, num_warps=2),
413
- triton.Config({}, num_warps=4),
414
- triton.Config({}, num_warps=8),
415
- triton.Config({}, num_warps=16),
416
- triton.Config({}, num_warps=32),
417
- ],
418
- key=[
419
- "N",
420
- "HAS_DRESIDUAL",
421
- "STORE_DRESIDUAL",
422
- "IS_RMS_NORM",
423
- "HAS_BIAS",
424
- "HAS_DROPOUT",
425
- ],
426
  )
 
427
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
428
  # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
429
  # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
430
- @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
431
- @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
432
- @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
433
- @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
434
- @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
435
  @triton.jit
436
  def _layer_norm_bwd_kernel(
437
  X, # pointer to the input
@@ -465,6 +515,7 @@ def _layer_norm_bwd_kernel(
465
  N, # number of columns in X
466
  eps, # epsilon to avoid division by zero
467
  dropout_p,
 
468
  rows_per_program,
469
  IS_RMS_NORM: tl.constexpr,
470
  BLOCK_N: tl.constexpr,
@@ -498,10 +549,14 @@ def _layer_norm_bwd_kernel(
498
  if RECOMPUTE_OUTPUT:
499
  Y += row_start * stride_y_row
500
  w = tl.load(W + cols, mask=mask).to(tl.float32)
 
 
501
  if RECOMPUTE_OUTPUT and HAS_BIAS:
502
  b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
503
  if HAS_DY1:
504
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
 
 
505
  dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
506
  if HAS_BIAS:
507
  db = tl.zeros((BLOCK_N,), dtype=tl.float32)
@@ -550,18 +605,14 @@ def _layer_norm_bwd_kernel(
550
  if HAS_DX1:
551
  if HAS_DROPOUT:
552
  keep_mask = (
553
- tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7)
554
- > dropout_p
555
  )
556
  dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
557
  else:
558
  dx1 = dx
559
  tl.store(DX1 + cols, dx1, mask=mask)
560
  if HAS_DROPOUT:
561
- keep_mask = (
562
- tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7)
563
- > dropout_p
564
- )
565
  dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
566
  if HAS_ROWSCALE:
567
  rowscale = tl.load(ROWSCALE + row).to(tl.float32)
@@ -591,31 +642,93 @@ def _layer_norm_bwd_kernel(
591
 
592
 
593
  def _layer_norm_bwd(
594
- dy,
595
- x,
596
- weight,
597
- bias,
598
- eps,
599
- mean,
600
- rstd,
601
- dresidual=None,
602
- dy1=None,
603
- weight1=None,
604
- bias1=None,
605
- seeds=None,
606
- dropout_p=0.0,
607
- rowscale=None,
608
- has_residual=False,
609
- has_x1=False,
610
- is_rms_norm=False,
611
- x_dtype=None,
612
- recompute_output=False,
613
- ):
 
 
 
 
 
 
 
 
 
614
  M, N = x.shape
615
  assert x.stride(-1) == 1
 
616
  assert dy.stride(-1) == 1
617
  assert dy.shape == (M, N)
618
  if dresidual is not None:
 
619
  assert dresidual.stride(-1) == 1
620
  assert dresidual.shape == (M, N)
621
  assert weight.shape == (N,)
@@ -624,6 +737,7 @@ def _layer_norm_bwd(
624
  assert bias.stride(-1) == 1
625
  assert bias.shape == (N,)
626
  if dy1 is not None:
 
627
  assert weight1 is not None
628
  assert dy1.shape == dy.shape
629
  assert dy1.stride(-1) == 1
@@ -652,22 +766,18 @@ def _layer_norm_bwd(
652
  else None
653
  )
654
  dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
655
- y = (
656
- torch.empty(M, N, dtype=dy.dtype, device=dy.device)
657
- if recompute_output
658
- else None
659
- )
660
  if recompute_output:
661
- assert (
662
- weight1 is None
663
- ), "recompute_output is not supported with parallel LayerNorm"
664
 
665
  # Less than 64KB per feature: enqueue fused kernel
666
  MAX_FUSED_SIZE = 65536 // x.element_size()
667
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
668
  if N > BLOCK_N:
669
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
670
- sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count
 
 
671
  _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
672
  _db = (
673
  torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
@@ -679,7 +789,7 @@ def _layer_norm_bwd(
679
  rows_per_program = math.ceil(M / sm_count)
680
  grid = (sm_count,)
681
  with torch.cuda.device(x.device.index):
682
- _layer_norm_bwd_kernel[grid](
683
  x,
684
  weight,
685
  bias,
@@ -711,6 +821,8 @@ def _layer_norm_bwd(
711
  N,
712
  eps,
713
  dropout_p,
 
 
714
  rows_per_program,
715
  is_rms_norm,
716
  BLOCK_N,
@@ -718,24 +830,22 @@ def _layer_norm_bwd(
718
  dresidual_in is not None,
719
  bias is not None,
720
  dropout_p > 0.0,
 
 
 
 
 
721
  )
722
  dw = _dw.sum(0).to(weight.dtype)
723
  db = _db.sum(0).to(bias.dtype) if bias is not None else None
724
  dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
725
  db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
726
- # Don't need to compute dresidual_in separately in this case
727
- if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
728
- dresidual_in = dx
729
- if has_x1 and dropout_p == 0.0:
730
- dx1 = dx
731
- return (
732
- (dx, dw, db, dresidual_in, dx1, dw1, db1)
733
- if not recompute_output
734
- else (dx, dw, db, dresidual_in, dx1, dw1, db1, y)
735
- )
736
 
737
 
738
  class LayerNormFn(torch.autograd.Function):
 
739
  @staticmethod
740
  def forward(
741
  ctx,
@@ -751,34 +861,27 @@ class LayerNormFn(torch.autograd.Function):
751
  rowscale=None,
752
  prenorm=False,
753
  residual_in_fp32=False,
 
754
  is_rms_norm=False,
755
  return_dropout_mask=False,
 
756
  out=None,
757
- residual_out=None,
758
  ):
759
  x_shape_og = x.shape
760
  # reshape input data into 2D tensor
761
- x = x.reshape(-1, x.shape[-1])
762
- if x.stride(-1) != 1:
763
- x = x.contiguous()
764
  if residual is not None:
765
  assert residual.shape == x_shape_og
766
- residual = residual.reshape(-1, residual.shape[-1])
767
- if residual.stride(-1) != 1:
768
- residual = residual.contiguous()
769
  if x1 is not None:
770
  assert x1.shape == x_shape_og
771
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
772
- x1 = x1.reshape(-1, x1.shape[-1])
773
- if x1.stride(-1) != 1:
774
- x1 = x1.contiguous()
775
  weight = weight.contiguous()
776
- if bias is not None:
777
- bias = bias.contiguous()
778
- if weight1 is not None:
779
- weight1 = weight1.contiguous()
780
- if bias1 is not None:
781
- bias1 = bias1.contiguous()
782
  if rowscale is not None:
783
  rowscale = rowscale.reshape(-1).contiguous()
784
  residual_dtype = (
@@ -790,24 +893,24 @@ class LayerNormFn(torch.autograd.Function):
790
  out = out.reshape(-1, out.shape[-1])
791
  if residual_out is not None:
792
  residual_out = residual_out.reshape(-1, residual_out.shape[-1])
793
- y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = (
794
- _layer_norm_fwd(
795
- x,
796
- weight,
797
- bias,
798
- eps,
799
- residual,
800
- x1,
801
- weight1,
802
- bias1,
803
- dropout_p=dropout_p,
804
- rowscale=rowscale,
805
- residual_dtype=residual_dtype,
806
- is_rms_norm=is_rms_norm,
807
- return_dropout_mask=return_dropout_mask,
808
- out=out,
809
- residual_out=residual_out,
810
- )
811
  )
812
  ctx.save_for_backward(
813
  residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
@@ -820,17 +923,12 @@ class LayerNormFn(torch.autograd.Function):
820
  ctx.has_x1 = x1 is not None
821
  ctx.prenorm = prenorm
822
  ctx.x_dtype = x.dtype
 
823
  y = y.reshape(x_shape_og)
824
  y1 = y1.reshape(x_shape_og) if y1 is not None else None
825
- residual_out = (
826
- residual_out.reshape(x_shape_og) if residual_out is not None else None
827
- )
828
- dropout_mask = (
829
- dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
830
- )
831
- dropout_mask1 = (
832
- dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
833
- )
834
  if not return_dropout_mask:
835
  if weight1 is None:
836
  return y if not prenorm else (y, residual_out)
@@ -854,26 +952,19 @@ class LayerNormFn(torch.autograd.Function):
854
  def backward(ctx, dy, *args):
855
  x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
856
  dy = dy.reshape(-1, dy.shape[-1])
857
- if dy.stride(-1) != 1:
858
- dy = dy.contiguous()
859
- assert dy.shape == x.shape
860
  if weight1 is not None:
861
  dy1, args = args[0], args[1:]
862
  dy1 = dy1.reshape(-1, dy1.shape[-1])
863
- if dy1.stride(-1) != 1:
864
- dy1 = dy1.contiguous()
865
  assert dy1.shape == x.shape
866
  else:
867
  dy1 = None
868
  if ctx.prenorm:
869
  dresidual = args[0]
870
  dresidual = dresidual.reshape(-1, dresidual.shape[-1])
871
- if dresidual.stride(-1) != 1:
872
- dresidual = dresidual.contiguous()
873
  assert dresidual.shape == x.shape
874
  else:
875
  dresidual = None
876
- dx, dw, db, dresidual_in, dx1, dw1, db1 = _layer_norm_bwd(
877
  dy,
878
  x,
879
  weight,
@@ -890,8 +981,10 @@ class LayerNormFn(torch.autograd.Function):
890
  rowscale,
891
  ctx.has_residual,
892
  ctx.has_x1,
 
893
  ctx.is_rms_norm,
894
  x_dtype=ctx.x_dtype,
 
895
  )
896
  return (
897
  dx.reshape(ctx.x_shape_og),
@@ -910,6 +1003,8 @@ class LayerNormFn(torch.autograd.Function):
910
  None,
911
  None,
912
  None,
 
 
913
  )
914
 
915
 
@@ -926,10 +1021,12 @@ def layer_norm_fn(
926
  rowscale=None,
927
  prenorm=False,
928
  residual_in_fp32=False,
 
929
  is_rms_norm=False,
930
  return_dropout_mask=False,
 
931
  out=None,
932
- residual_out=None,
933
  ):
934
  return LayerNormFn.apply(
935
  x,
@@ -944,10 +1041,12 @@ def layer_norm_fn(
944
  rowscale,
945
  prenorm,
946
  residual_in_fp32,
 
947
  is_rms_norm,
948
  return_dropout_mask,
 
949
  out,
950
- residual_out,
951
  )
952
 
953
 
@@ -964,9 +1063,11 @@ def rms_norm_fn(
964
  rowscale=None,
965
  prenorm=False,
966
  residual_in_fp32=False,
 
967
  return_dropout_mask=False,
 
968
  out=None,
969
- residual_out=None,
970
  ):
971
  return LayerNormFn.apply(
972
  x,
@@ -981,16 +1082,19 @@ def rms_norm_fn(
981
  rowscale,
982
  prenorm,
983
  residual_in_fp32,
 
984
  True,
985
  return_dropout_mask,
 
986
  out,
987
- residual_out,
988
  )
989
 
990
 
991
  class RMSNorm(torch.nn.Module):
992
 
993
- def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, device=None, dtype=None):
 
994
  factory_kwargs = {"device": device, "dtype": dtype}
995
  super().__init__()
996
  self.eps = eps
@@ -998,12 +1102,16 @@ class RMSNorm(torch.nn.Module):
998
  self.drop = torch.nn.Dropout(dropout_p)
999
  else:
1000
  self.drop = None
 
1001
  self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1002
  self.register_parameter("bias", None)
1003
  self.reset_parameters()
1004
 
1005
  def reset_parameters(self):
1006
- torch.nn.init.ones_(self.weight)
 
 
 
1007
 
1008
  def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1009
  return rms_norm_fn(
@@ -1015,12 +1123,14 @@ class RMSNorm(torch.nn.Module):
1015
  dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1016
  prenorm=prenorm,
1017
  residual_in_fp32=residual_in_fp32,
 
1018
  )
1019
 
1020
 
1021
  class LayerNormLinearFn(torch.autograd.Function):
 
1022
  @staticmethod
1023
- @custom_fwd(device_type="cuda")
1024
  def forward(
1025
  ctx,
1026
  x,
@@ -1036,17 +1146,12 @@ class LayerNormLinearFn(torch.autograd.Function):
1036
  ):
1037
  x_shape_og = x.shape
1038
  # reshape input data into 2D tensor
1039
- x = x.reshape(-1, x.shape[-1])
1040
- if x.stride(-1) != 1:
1041
- x = x.contiguous()
1042
  if residual is not None:
1043
  assert residual.shape == x_shape_og
1044
- residual = residual.reshape(-1, residual.shape[-1])
1045
- if residual.stride(-1) != 1:
1046
- residual = residual.contiguous()
1047
  norm_weight = norm_weight.contiguous()
1048
- if norm_bias is not None:
1049
- norm_bias = norm_bias.contiguous()
1050
  residual_dtype = (
1051
  residual.dtype
1052
  if residual is not None
@@ -1058,25 +1163,17 @@ class LayerNormLinearFn(torch.autograd.Function):
1058
  norm_bias,
1059
  eps,
1060
  residual,
1061
- out_dtype=(
1062
- None
1063
- if not torch.is_autocast_enabled()
1064
- else torch.get_autocast_gpu_dtype()
1065
- ),
1066
  residual_dtype=residual_dtype,
1067
  is_rms_norm=is_rms_norm,
1068
  )
1069
  y = y.reshape(x_shape_og)
1070
- dtype = (
1071
- torch.get_autocast_gpu_dtype() if torch.is_autocast_enabled() else y.dtype
1072
- )
1073
  linear_weight = linear_weight.to(dtype)
1074
  linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1075
  out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1076
  # We don't store y, will be recomputed in the backward pass to save memory
1077
- ctx.save_for_backward(
1078
- residual_out, norm_weight, norm_bias, linear_weight, mean, rstd
1079
- )
1080
  ctx.x_shape_og = x_shape_og
1081
  ctx.eps = eps
1082
  ctx.is_rms_norm = is_rms_norm
@@ -1087,20 +1184,17 @@ class LayerNormLinearFn(torch.autograd.Function):
1087
  return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1088
 
1089
  @staticmethod
1090
- @custom_bwd(device_type="cuda")
1091
  def backward(ctx, dout, *args):
1092
  x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1093
  dout = dout.reshape(-1, dout.shape[-1])
1094
  dy = F.linear(dout, linear_weight.t())
1095
  dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1096
- if dy.stride(-1) != 1:
1097
- dy = dy.contiguous()
1098
  assert dy.shape == x.shape
1099
  if ctx.prenorm:
1100
  dresidual = args[0]
1101
- dresidual = dresidual.reshape(-1, dresidual.shape[-1])
1102
- if dresidual.stride(-1) != 1:
1103
- dresidual = dresidual.contiguous()
1104
  assert dresidual.shape == x.shape
1105
  else:
1106
  dresidual = None
 
7
  # The models we train have hidden dim up to 8k anyway (e.g. Llama 70B), so this is fine.
8
 
9
  import math
10
+ from typing import Optional, List
11
 
12
  import torch
13
  import torch.nn.functional as F
14
+ from torch import Tensor
15
 
16
  import triton
17
  import triton.language as tl
18
 
19
+ from ._ops import add_op_namespace_prefix
20
+ from .utils.torch import custom_fwd, custom_bwd
21
+ from .utils.library import triton_op
22
+
23
+
24
+ def maybe_contiguous_lastdim(x):
25
+ return x.contiguous() if x is not None and x.stride(-1) != 1 else x
26
+
27
+
28
+ def maybe_contiguous(x):
29
+ return x.contiguous() if x is not None else None
30
+
31
+
32
+ def triton_autotune_configs():
33
+ # Return configs with a valid warp count for the current device
34
+ configs = []
35
+ # Maximum threads per block is architecture-dependent in theory, but in reality all are 1024
36
+ max_threads_per_block = 1024
37
+ # Default to warp size 32 if not defined by device
38
+ warp_size = getattr(torch.cuda.get_device_properties(torch.cuda.current_device()), "warp_size", 32)
39
+ # Autotune for warp counts which are powers of 2 and do not exceed thread per block limit
40
+ return [triton.Config({}, num_warps=warp_count) for warp_count in [1, 2, 4, 8, 16, 32]
41
+ if warp_count * warp_size <= max_threads_per_block]
42
+ # return [triton.Config({}, num_warps=8)]
43
+
44
 
45
  def layer_norm_ref(
46
  x,
 
54
  dropout_p=0.0,
55
  rowscale=None,
56
  prenorm=False,
57
+ zero_centered_weight=False,
58
  dropout_mask=None,
59
  dropout_mask1=None,
60
  upcast=False,
 
68
  x1 = x1.float() if x1 is not None else None
69
  weight1 = weight1.float() if weight1 is not None else None
70
  bias1 = bias1.float() if bias1 is not None else None
71
+ if zero_centered_weight:
72
+ weight = weight + 1.0
73
+ if weight1 is not None:
74
+ weight1 = weight1 + 1.0
75
  if x1 is not None:
76
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
77
  if rowscale is not None:
 
90
  x = x + x1
91
  if residual is not None:
92
  x = (x + residual).to(x.dtype)
93
+ out = F.layer_norm(x.to(weight.dtype), x.shape[-1:], weight=weight, bias=bias, eps=eps).to(
94
+ dtype
95
+ )
96
  if weight1 is None:
97
  return out if not prenorm else (out, x)
98
  else:
 
114
  dropout_p=0.0,
115
  rowscale=None,
116
  prenorm=False,
117
+ zero_centered_weight=False,
118
  dropout_mask=None,
119
  dropout_mask1=None,
120
  upcast=False,
 
128
  x1 = x1.float() if x1 is not None else None
129
  weight1 = weight1.float() if weight1 is not None else None
130
  bias1 = bias1.float() if bias1 is not None else None
131
+ if zero_centered_weight:
132
+ weight = weight + 1.0
133
+ if weight1 is not None:
134
+ weight1 = weight1 + 1.0
135
  if x1 is not None:
136
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
137
  if rowscale is not None:
 
151
  if residual is not None:
152
  x = (x + residual).to(x.dtype)
153
  rstd = 1 / torch.sqrt((x.square()).mean(dim=-1, keepdim=True) + eps)
154
+ out = ((x * rstd * weight) + bias if bias is not None else (x * rstd * weight)).to(dtype)
 
 
155
  if weight1 is None:
156
  return out if not prenorm else (out, x)
157
  else:
158
+ out1 = ((x * rstd * weight1) + bias1 if bias1 is not None else (x * rstd * weight1)).to(
159
+ dtype
160
+ )
161
  return (out, out1) if not prenorm else (out, out1, x)
162
 
163
 
164
  @triton.autotune(
165
+ configs=triton_autotune_configs(),
166
+ key=["N", "HAS_RESIDUAL", "STORE_RESIDUAL_OUT", "IS_RMS_NORM", "HAS_BIAS", "HAS_X1", "HAS_W1", "HAS_B1"],
 
 
 
 
 
 
 
167
  )
168
+ # torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
169
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
170
  # @triton.heuristics({"HAS_RESIDUAL": lambda args: args["RESIDUAL"] is not None})
171
+ # @triton.heuristics({"HAS_X1": lambda args: args["X1"] is not None})
172
+ # @triton.heuristics({"HAS_W1": lambda args: args["W1"] is not None})
173
+ # @triton.heuristics({"HAS_B1": lambda args: args["B1"] is not None})
174
  @triton.jit
175
  def _layer_norm_fwd_1pass_kernel(
176
  X, # pointer to the input
 
186
  ROWSCALE,
187
  SEEDS, # Dropout seeds for each row
188
  DROPOUT_MASK,
189
+ DROPOUT_MASK1,
190
  Mean, # pointer to the mean
191
  Rstd, # pointer to the 1/std
192
  stride_x_row, # how much to increase the pointer when moving by 1 row
 
199
  N, # number of columns in X
200
  eps, # epsilon to avoid division by zero
201
  dropout_p, # Dropout probability
202
+ zero_centered_weight, # If true, add 1.0 to the weight
203
  IS_RMS_NORM: tl.constexpr,
204
  BLOCK_N: tl.constexpr,
205
  HAS_RESIDUAL: tl.constexpr,
 
233
  if HAS_DROPOUT:
234
  # Compute dropout mask
235
  # 7 rounds is good enough, and reduces register pressure
236
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
 
237
  x = tl.where(keep_mask, x / (1.0 - dropout_p), 0.0)
238
  if STORE_DROPOUT_MASK:
239
  tl.store(DROPOUT_MASK + row * N + cols, keep_mask, mask=cols < N)
 
246
  # Compute dropout mask
247
  # 7 rounds is good enough, and reduces register pressure
248
  keep_mask = (
249
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
250
  )
251
  x1 = tl.where(keep_mask, x1 / (1.0 - dropout_p), 0.0)
252
  if STORE_DROPOUT_MASK:
253
+ tl.store(DROPOUT_MASK1 + row * N + cols, keep_mask, mask=cols < N)
254
  x += x1
255
  if HAS_RESIDUAL:
256
  residual = tl.load(RESIDUAL + cols, mask=cols < N, other=0.0).to(tl.float32)
 
270
  # Normalize and apply linear transformation
271
  mask = cols < N
272
  w = tl.load(W + cols, mask=mask).to(tl.float32)
273
+ if zero_centered_weight:
274
+ w += 1.0
275
  if HAS_BIAS:
276
  b = tl.load(B + cols, mask=mask).to(tl.float32)
277
  x_hat = (x - mean) * rstd if not IS_RMS_NORM else x * rstd
 
280
  tl.store(Y + cols, y, mask=mask)
281
  if HAS_W1:
282
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
283
+ if zero_centered_weight:
284
+ w1 += 1.0
285
  if HAS_B1:
286
  b1 = tl.load(B1 + cols, mask=mask).to(tl.float32)
287
  y1 = x_hat * w1 + b1 if HAS_B1 else x_hat * w1
 
289
 
290
 
291
  def _layer_norm_fwd(
292
+ x: Tensor,
293
+ weight: Tensor,
294
+ bias: Tensor,
295
+ eps: float,
296
+ residual: Optional[Tensor] = None,
297
+ x1: Optional[Tensor] = None,
298
+ weight1: Optional[Tensor] = None,
299
+ bias1: Optional[Tensor] = None,
300
+ dropout_p: float = 0.0,
301
+ rowscale: Optional[Tensor] = None,
302
+ out_dtype: Optional[torch.dtype] = None,
303
+ residual_dtype: Optional[torch.dtype] = None,
304
+ zero_centered_weight: bool = False,
305
+ is_rms_norm: bool = False,
306
+ return_dropout_mask: bool = False,
307
+ out: Optional[Tensor] = None,
308
+ residual_out: Optional[Tensor] = None
309
+ ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
310
+ # Need to wrap to handle the case where residual_out is an alias of x, which makes torch.library
311
+ # and torch.compile unhappy. Also allocate memory for out and residual_out if they are None
312
+ # so that _layer_norm_fwd_impl doesn't have to return them.
313
+ if out is None:
314
+ out = torch.empty_like(x, dtype=x.dtype if out_dtype is None else out_dtype)
315
  if residual is not None:
316
  residual_dtype = residual.dtype
317
+ if residual_out is None and (
318
+ residual is not None
319
+ or (residual_dtype is not None and residual_dtype != x.dtype)
320
+ or dropout_p > 0.0
321
+ or rowscale is not None
322
+ or x1 is not None
323
+ ):
324
+ residual_out = torch.empty_like(
325
+ x, dtype=residual_dtype if residual_dtype is not None else x.dtype
326
+ )
327
+ else:
328
+ residual_out = None
329
+ y1, mean, rstd, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd_impl(
330
+ x,
331
+ weight,
332
+ bias,
333
+ eps,
334
+ out,
335
+ residual=residual,
336
+ x1=x1,
337
+ weight1=weight1,
338
+ bias1=bias1,
339
+ dropout_p=dropout_p,
340
+ rowscale=rowscale,
341
+ zero_centered_weight=zero_centered_weight,
342
+ is_rms_norm=is_rms_norm,
343
+ return_dropout_mask=return_dropout_mask,
344
+ residual_out=residual_out,
345
+ )
346
+ # residual_out is None if residual is None and residual_dtype == input_dtype and dropout_p == 0.0
347
+ if residual_out is None:
348
+ residual_out = x
349
+ return out, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1
350
+
351
+
352
+ # [2025-04-28] torch.library.triton_op ignores the schema argument, but here we need the schema
353
+ # since we're returning a tuple of tensors
354
+ @triton_op(add_op_namespace_prefix("layer_norm_fwd_impl"), mutates_args={"out", "residual_out"},
355
+ schema="(Tensor x, Tensor weight, Tensor bias, float eps, Tensor(a!) out, Tensor? residual, Tensor? x1, Tensor? weight1, Tensor? bias1, float dropout_p, Tensor? rowscale, bool zero_centered_weight, bool is_rms_norm, bool return_dropout_mask, Tensor(a!)? residual_out) -> (Tensor y1, Tensor mean, Tensor rstd, Tensor seeds, Tensor dropout_mask, Tensor dropout_mask1)")
356
+ def _layer_norm_fwd_impl(
357
+ x: Tensor,
358
+ weight: Tensor,
359
+ bias: Tensor,
360
+ eps: float,
361
+ out: Tensor,
362
+ residual: Optional[Tensor] = None,
363
+ x1: Optional[Tensor] = None,
364
+ weight1: Optional[Tensor] = None,
365
+ bias1: Optional[Tensor] = None,
366
+ dropout_p: float = 0.0,
367
+ rowscale: Optional[Tensor] = None,
368
+ zero_centered_weight: bool = False,
369
+ is_rms_norm: bool = False,
370
+ return_dropout_mask: bool = False,
371
+ residual_out: Optional[Tensor] = None
372
+ ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
373
  M, N = x.shape
374
  assert x.stride(-1) == 1
375
  if residual is not None:
 
393
  if rowscale is not None:
394
  assert rowscale.is_contiguous()
395
  assert rowscale.shape == (M,)
396
+ assert out.shape == x.shape
 
 
 
 
397
  assert out.stride(-1) == 1
398
+ if residual_out is not None:
399
+ assert residual_out.shape == x.shape
400
+ assert residual_out.stride(-1) == 1
401
  if weight1 is not None:
402
  y1 = torch.empty_like(out)
403
  assert y1.stride(-1) == 1
404
  else:
405
  y1 = None
406
+ mean = torch.empty((M,), dtype=torch.float32, device=x.device) if not is_rms_norm else None
 
 
 
 
 
407
  rstd = torch.empty((M,), dtype=torch.float32, device=x.device)
408
  if dropout_p > 0.0:
409
  seeds = torch.randint(
 
412
  else:
413
  seeds = None
414
  if return_dropout_mask and dropout_p > 0.0:
415
+ dropout_mask = torch.empty(M, N, device=x.device, dtype=torch.bool)
416
+ if x1 is not None:
417
+ dropout_mask1 = torch.empty(M, N, device=x.device, dtype=torch.bool)
418
+ else:
419
+ dropout_mask1 = None
420
  else:
421
+ dropout_mask, dropout_mask1 = None, None
422
  # Less than 64KB per feature: enqueue fused kernel
423
  MAX_FUSED_SIZE = 65536 // x.element_size()
424
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
425
  if N > BLOCK_N:
426
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
427
  with torch.cuda.device(x.device.index):
428
+ torch.library.wrap_triton(_layer_norm_fwd_1pass_kernel)[(M,)](
429
  x,
430
  out,
431
  weight,
 
439
  rowscale,
440
  seeds,
441
  dropout_mask,
442
+ dropout_mask1,
443
  mean,
444
  rstd,
445
  x.stride(0),
 
452
  N,
453
  eps,
454
  dropout_p,
455
+ # Passing bool makes torch inductor very unhappy since it then tries to compare to int_max
456
+ int(zero_centered_weight),
457
  is_rms_norm,
458
  BLOCK_N,
459
  residual is not None,
 
462
  dropout_p > 0.0,
463
  dropout_mask is not None,
464
  rowscale is not None,
465
+ HAS_X1=x1 is not None,
466
+ HAS_W1=weight1 is not None,
467
+ HAS_B1=bias1 is not None,
468
  )
469
+ return y1, mean, rstd, seeds, dropout_mask, dropout_mask1
 
 
 
 
470
 
471
 
472
  @triton.autotune(
473
+ configs=triton_autotune_configs(),
474
+ key=["N", "HAS_DRESIDUAL", "STORE_DRESIDUAL", "IS_RMS_NORM", "HAS_BIAS", "HAS_DROPOUT"],
 
 
 
 
475
  )
476
+ # torch compile doesn't like triton.heuristics, so we set these manually when calling the kernel
477
  # @triton.heuristics({"HAS_BIAS": lambda args: args["B"] is not None})
478
  # @triton.heuristics({"HAS_DRESIDUAL": lambda args: args["DRESIDUAL"] is not None})
479
  # @triton.heuristics({"STORE_DRESIDUAL": lambda args: args["DRESIDUAL_IN"] is not None})
480
+ # @triton.heuristics({"HAS_ROWSCALE": lambda args: args["ROWSCALE"] is not None})
481
+ # @triton.heuristics({"HAS_DY1": lambda args: args["DY1"] is not None})
482
+ # @triton.heuristics({"HAS_DX1": lambda args: args["DX1"] is not None})
483
+ # @triton.heuristics({"HAS_B1": lambda args: args["DB1"] is not None})
484
+ # @triton.heuristics({"RECOMPUTE_OUTPUT": lambda args: args["Y"] is not None})
485
  @triton.jit
486
  def _layer_norm_bwd_kernel(
487
  X, # pointer to the input
 
515
  N, # number of columns in X
516
  eps, # epsilon to avoid division by zero
517
  dropout_p,
518
+ zero_centered_weight,
519
  rows_per_program,
520
  IS_RMS_NORM: tl.constexpr,
521
  BLOCK_N: tl.constexpr,
 
549
  if RECOMPUTE_OUTPUT:
550
  Y += row_start * stride_y_row
551
  w = tl.load(W + cols, mask=mask).to(tl.float32)
552
+ if zero_centered_weight:
553
+ w += 1.0
554
  if RECOMPUTE_OUTPUT and HAS_BIAS:
555
  b = tl.load(B + cols, mask=mask, other=0.0).to(tl.float32)
556
  if HAS_DY1:
557
  w1 = tl.load(W1 + cols, mask=mask).to(tl.float32)
558
+ if zero_centered_weight:
559
+ w1 += 1.0
560
  dw = tl.zeros((BLOCK_N,), dtype=tl.float32)
561
  if HAS_BIAS:
562
  db = tl.zeros((BLOCK_N,), dtype=tl.float32)
 
605
  if HAS_DX1:
606
  if HAS_DROPOUT:
607
  keep_mask = (
608
+ tl.rand(tl.load(SEEDS + M + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
609
  )
610
  dx1 = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
611
  else:
612
  dx1 = dx
613
  tl.store(DX1 + cols, dx1, mask=mask)
614
  if HAS_DROPOUT:
615
+ keep_mask = tl.rand(tl.load(SEEDS + row).to(tl.uint32), cols, n_rounds=7) > dropout_p
 
 
 
616
  dx = tl.where(keep_mask, dx / (1.0 - dropout_p), 0.0)
617
  if HAS_ROWSCALE:
618
  rowscale = tl.load(ROWSCALE + row).to(tl.float32)
 
642
 
643
 
644
  def _layer_norm_bwd(
645
+ dy: Tensor,
646
+ x: Tensor,
647
+ weight: Tensor,
648
+ bias: Tensor,
649
+ eps: float,
650
+ mean: Tensor,
651
+ rstd: Tensor,
652
+ dresidual: Optional[Tensor] = None,
653
+ dy1: Optional[Tensor] = None,
654
+ weight1: Optional[Tensor] = None,
655
+ bias1: Optional[Tensor] = None,
656
+ seeds: Optional[Tensor] = None,
657
+ dropout_p: float = 0.0,
658
+ rowscale: Optional[Tensor] = None,
659
+ has_residual: bool = False,
660
+ has_x1: bool = False,
661
+ zero_centered_weight: bool = False,
662
+ is_rms_norm: bool = False,
663
+ x_dtype: Optional[torch.dtype] = None,
664
+ recompute_output: bool = False,
665
+ ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
666
+ # Need to wrap to handle the case where dresidual_in or dx1 are aliases of x,
667
+ # which makes torch.library unhappy
668
+ dx, dw, db, dresidual_in, dx1, dw1, db1, y = _layer_norm_bwd_impl(
669
+ dy,
670
+ x,
671
+ weight,
672
+ bias,
673
+ eps,
674
+ mean,
675
+ rstd,
676
+ dresidual,
677
+ dy1,
678
+ weight1,
679
+ bias1,
680
+ seeds,
681
+ dropout_p,
682
+ rowscale,
683
+ has_residual,
684
+ has_x1,
685
+ zero_centered_weight,
686
+ is_rms_norm,
687
+ x_dtype=x_dtype,
688
+ recompute_output=recompute_output,
689
+ )
690
+ # Don't need to compute dresidual_in separately in this case
691
+ if has_residual and dx.dtype == x.dtype and dropout_p == 0.0 and rowscale is None:
692
+ dresidual_in = dx
693
+ if has_x1 and dropout_p == 0.0:
694
+ dx1 = dx
695
+ return dx, dw, db, dresidual_in, dx1, dw1, db1, y
696
+
697
+
698
+
699
+ @triton_op(add_op_namespace_prefix("layer_norm_bwd_impl"), mutates_args={},
700
+ schema="(Tensor dy, Tensor x, Tensor weight, Tensor bias, float eps, Tensor mean, Tensor rstd, Tensor? dresidual, Tensor? dy1, Tensor? weight1, Tensor? bias1, Tensor? seeds, float dropout_p, Tensor? rowscale, bool has_residual, bool has_x1, bool zero_centered_weight, bool is_rms_norm, ScalarType? x_dtype, bool recompute_output) -> (Tensor dx, Tensor dw, Tensor db, Tensor dresidual_in, Tensor dx1, Tensor dw1, Tensor db1, Tensor y)",
701
+ allow_decomposition=False, # Don't let torch.compile trace inside
702
+ )
703
+ def _layer_norm_bwd_impl(
704
+ dy: Tensor,
705
+ x: Tensor,
706
+ weight: Tensor,
707
+ bias: Tensor,
708
+ eps: float,
709
+ mean: Tensor,
710
+ rstd: Tensor,
711
+ dresidual: Optional[Tensor] = None,
712
+ dy1: Optional[Tensor] = None,
713
+ weight1: Optional[Tensor] = None,
714
+ bias1: Optional[Tensor] = None,
715
+ seeds: Optional[Tensor] = None,
716
+ dropout_p: float = 0.0,
717
+ rowscale: Optional[Tensor] = None,
718
+ has_residual: bool = False,
719
+ has_x1: bool = False,
720
+ zero_centered_weight: bool = False,
721
+ is_rms_norm: bool = False,
722
+ x_dtype: Optional[torch.dtype] = None,
723
+ recompute_output: bool = False,
724
+ ) -> (Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor, Tensor):
725
  M, N = x.shape
726
  assert x.stride(-1) == 1
727
+ dy = maybe_contiguous_lastdim(dy)
728
  assert dy.stride(-1) == 1
729
  assert dy.shape == (M, N)
730
  if dresidual is not None:
731
+ dresidual = maybe_contiguous_lastdim(dresidual)
732
  assert dresidual.stride(-1) == 1
733
  assert dresidual.shape == (M, N)
734
  assert weight.shape == (N,)
 
737
  assert bias.stride(-1) == 1
738
  assert bias.shape == (N,)
739
  if dy1 is not None:
740
+ dy1 = maybe_contiguous_lastdim(dy1)
741
  assert weight1 is not None
742
  assert dy1.shape == dy.shape
743
  assert dy1.stride(-1) == 1
 
766
  else None
767
  )
768
  dx1 = torch.empty_like(dx) if (has_x1 and dropout_p > 0.0) else None
769
+ y = torch.empty(M, N, dtype=dy.dtype, device=dy.device) if recompute_output else None
 
 
 
 
770
  if recompute_output:
771
+ assert weight1 is None, "recompute_output is not supported with parallel LayerNorm"
 
 
772
 
773
  # Less than 64KB per feature: enqueue fused kernel
774
  MAX_FUSED_SIZE = 65536 // x.element_size()
775
  BLOCK_N = min(MAX_FUSED_SIZE, triton.next_power_of_2(N))
776
  if N > BLOCK_N:
777
  raise RuntimeError("This layer norm doesn't support feature dim >= 64KB.")
778
+ # Increasing the multiple (e.g. 8) will allow more thread blocks to be launched and hide the
779
+ # latency of the gmem reads/writes, but will increase the time spent summing up dw / db.
780
+ sm_count = torch.cuda.get_device_properties(x.device).multi_processor_count * 8
781
  _dw = torch.empty((sm_count, N), dtype=torch.float32, device=weight.device)
782
  _db = (
783
  torch.empty((sm_count, N), dtype=torch.float32, device=bias.device)
 
789
  rows_per_program = math.ceil(M / sm_count)
790
  grid = (sm_count,)
791
  with torch.cuda.device(x.device.index):
792
+ torch.library.wrap_triton(_layer_norm_bwd_kernel)[grid](
793
  x,
794
  weight,
795
  bias,
 
821
  N,
822
  eps,
823
  dropout_p,
824
+ # Passing a bool makes torch inductor very unhappy, since it then tries to compare it to int_max
825
+ int(zero_centered_weight),
826
  rows_per_program,
827
  is_rms_norm,
828
  BLOCK_N,
 
830
  dresidual_in is not None,
831
  bias is not None,
832
  dropout_p > 0.0,
833
+ HAS_ROWSCALE=rowscale is not None,
834
+ HAS_DY1=dy1 is not None,
835
+ HAS_DX1=dx1 is not None,
836
+ HAS_B1=bias1 is not None,
837
+ RECOMPUTE_OUTPUT=y is not None,
838
  )
839
  dw = _dw.sum(0).to(weight.dtype)
840
  db = _db.sum(0).to(bias.dtype) if bias is not None else None
841
  dw1 = _dw1.sum(0).to(weight1.dtype) if weight1 is not None else None
842
  db1 = _db1.sum(0).to(bias1.dtype) if bias1 is not None else None
843
+ # dresidual_in and dx1 could be None; the wrapper will handle assigning them from dx
844
+ return dx, dw, db, dresidual_in, dx1, dw1, db1, y
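
The weight/bias gradients are computed as a split reduction: the launch uses `sm_count * 8` programs, each accumulating a float32 partial row of `dw`/`db`, and the host sums the partials afterwards. A small sketch of that two-stage pattern with hypothetical sizes:

```python
import math
import torch

M, N = 8192, 1024                       # hypothetical problem size
sm_count = 132 * 8                      # e.g. 132 SMs on an H100, times the occupancy multiple

# Stage 1 (done by the kernel): each program accumulates one fp32 partial row of dw.
partial_dw = torch.randn(sm_count, N, dtype=torch.float32)   # stand-in for the kernel's output

# Stage 2 (done on the host): reduce the partials, matching `_dw.sum(0).to(weight.dtype)` above.
dw = partial_dw.sum(0)

rows_per_program = math.ceil(M / sm_count)   # as in the launch configuration
```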
 
 
 
 
 
 
 
 
845
 
846
 
847
  class LayerNormFn(torch.autograd.Function):
848
+
849
  @staticmethod
850
  def forward(
851
  ctx,
 
861
  rowscale=None,
862
  prenorm=False,
863
  residual_in_fp32=False,
864
+ zero_centered_weight=False,
865
  is_rms_norm=False,
866
  return_dropout_mask=False,
867
+ out_dtype=None,
868
  out=None,
869
+ residual_out=None
870
  ):
871
  x_shape_og = x.shape
872
  # reshape input data into 2D tensor
873
+ x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
 
 
874
  if residual is not None:
875
  assert residual.shape == x_shape_og
876
+ residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
 
 
877
  if x1 is not None:
878
  assert x1.shape == x_shape_og
879
  assert rowscale is None, "rowscale is not supported with parallel LayerNorm"
880
+ x1 = maybe_contiguous_lastdim(x1.reshape(-1, x1.shape[-1]))
 
 
881
  weight = weight.contiguous()
882
+ bias = maybe_contiguous(bias)
883
+ weight1 = maybe_contiguous(weight1)
884
+ bias1 = maybe_contiguous(bias1)
 
 
 
885
  if rowscale is not None:
886
  rowscale = rowscale.reshape(-1).contiguous()
887
  residual_dtype = (
 
893
  out = out.reshape(-1, out.shape[-1])
894
  if residual_out is not None:
895
  residual_out = residual_out.reshape(-1, residual_out.shape[-1])
896
+ y, y1, mean, rstd, residual_out, seeds, dropout_mask, dropout_mask1 = _layer_norm_fwd(
897
+ x,
898
+ weight,
899
+ bias,
900
+ eps,
901
+ residual,
902
+ x1,
903
+ weight1,
904
+ bias1,
905
+ dropout_p=dropout_p,
906
+ rowscale=rowscale,
907
+ out_dtype=out_dtype,
908
+ residual_dtype=residual_dtype,
909
+ zero_centered_weight=zero_centered_weight,
910
+ is_rms_norm=is_rms_norm,
911
+ return_dropout_mask=return_dropout_mask,
912
+ out=out,
913
+ residual_out=residual_out,
914
  )
915
  ctx.save_for_backward(
916
  residual_out, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd
 
923
  ctx.has_x1 = x1 is not None
924
  ctx.prenorm = prenorm
925
  ctx.x_dtype = x.dtype
926
+ ctx.zero_centered_weight = zero_centered_weight
927
  y = y.reshape(x_shape_og)
928
  y1 = y1.reshape(x_shape_og) if y1 is not None else None
929
+ residual_out = residual_out.reshape(x_shape_og) if residual_out is not None else None
930
+ dropout_mask = dropout_mask.reshape(x_shape_og) if dropout_mask is not None else None
931
+ dropout_mask1 = dropout_mask1.reshape(x_shape_og) if dropout_mask1 is not None else None
 
 
 
 
 
 
932
  if not return_dropout_mask:
933
  if weight1 is None:
934
  return y if not prenorm else (y, residual_out)
 
952
  def backward(ctx, dy, *args):
953
  x, weight, bias, weight1, bias1, rowscale, seeds, mean, rstd = ctx.saved_tensors
954
  dy = dy.reshape(-1, dy.shape[-1])
 
 
 
955
  if weight1 is not None:
956
  dy1, args = args[0], args[1:]
957
  dy1 = dy1.reshape(-1, dy1.shape[-1])
 
 
958
  assert dy1.shape == x.shape
959
  else:
960
  dy1 = None
961
  if ctx.prenorm:
962
  dresidual = args[0]
963
  dresidual = dresidual.reshape(-1, dresidual.shape[-1])
 
 
964
  assert dresidual.shape == x.shape
965
  else:
966
  dresidual = None
967
+ dx, dw, db, dresidual_in, dx1, dw1, db1, _ = _layer_norm_bwd(
968
  dy,
969
  x,
970
  weight,
 
981
  rowscale,
982
  ctx.has_residual,
983
  ctx.has_x1,
984
+ ctx.zero_centered_weight,
985
  ctx.is_rms_norm,
986
  x_dtype=ctx.x_dtype,
987
+ recompute_output=False,
988
  )
989
  return (
990
  dx.reshape(ctx.x_shape_og),
 
1003
  None,
1004
  None,
1005
  None,
1006
+ None,
1007
+ None,
1008
  )
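
The trailing `None`s exist because `autograd.Function.backward` must return exactly one value per argument of `forward`, with `None` for inputs that need no gradient; since the forward signature gained `zero_centered_weight`, `out_dtype`, and `residual_out`, `backward` returns matching `None` slots for them. A minimal illustration (a toy function, not the layer-norm op):

```python
import torch

class _Scale(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, alpha, out_dtype=None):
        ctx.alpha = alpha
        y = x * alpha
        return y.to(out_dtype) if out_dtype is not None else y

    @staticmethod
    def backward(ctx, grad_out):
        # One entry per forward input: a gradient for x, None for alpha and out_dtype.
        return grad_out * ctx.alpha, None, None
```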
1009
 
1010
 
 
1021
  rowscale=None,
1022
  prenorm=False,
1023
  residual_in_fp32=False,
1024
+ zero_centered_weight=False,
1025
  is_rms_norm=False,
1026
  return_dropout_mask=False,
1027
+ out_dtype=None,
1028
  out=None,
1029
+ residual_out=None
1030
  ):
1031
  return LayerNormFn.apply(
1032
  x,
 
1041
  rowscale,
1042
  prenorm,
1043
  residual_in_fp32,
1044
+ zero_centered_weight,
1045
  is_rms_norm,
1046
  return_dropout_mask,
1047
+ out_dtype,
1048
  out,
1049
+ residual_out
1050
  )
1051
 
1052
 
 
1063
  rowscale=None,
1064
  prenorm=False,
1065
  residual_in_fp32=False,
1066
+ zero_centered_weight=False,
1067
  return_dropout_mask=False,
1068
+ out_dtype=None,
1069
  out=None,
1070
+ residual_out=None
1071
  ):
1072
  return LayerNormFn.apply(
1073
  x,
 
1082
  rowscale,
1083
  prenorm,
1084
  residual_in_fp32,
1085
+ zero_centered_weight,
1086
  True,
1087
  return_dropout_mask,
1088
+ out_dtype,
1089
  out,
1090
+ residual_out
1091
  )
1092
 
1093
 
1094
  class RMSNorm(torch.nn.Module):
1095
 
1096
+ def __init__(self, hidden_size, eps=1e-5, dropout_p=0.0, zero_centered_weight=False,
1097
+ device=None, dtype=None):
1098
  factory_kwargs = {"device": device, "dtype": dtype}
1099
  super().__init__()
1100
  self.eps = eps
 
1102
  self.drop = torch.nn.Dropout(dropout_p)
1103
  else:
1104
  self.drop = None
1105
+ self.zero_centered_weight = zero_centered_weight
1106
  self.weight = torch.nn.Parameter(torch.empty(hidden_size, **factory_kwargs))
1107
  self.register_parameter("bias", None)
1108
  self.reset_parameters()
1109
 
1110
  def reset_parameters(self):
1111
+ if not self.zero_centered_weight:
1112
+ torch.nn.init.ones_(self.weight)
1113
+ else:
1114
+ torch.nn.init.zeros_(self.weight)
1115
 
1116
  def forward(self, x, residual=None, prenorm=False, residual_in_fp32=False):
1117
  return rms_norm_fn(
 
1123
  dropout_p=self.drop.p if self.drop is not None and self.training else 0.0,
1124
  prenorm=prenorm,
1125
  residual_in_fp32=residual_in_fp32,
1126
+ zero_centered_weight=self.zero_centered_weight,
1127
  )
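
With `zero_centered_weight=True`, the parameter is stored as an offset around zero (initialized to zeros above) and the kernels add `1.0` back before scaling, so the module matches a standard RMSNorm whose effective weight is `weight + 1`. A reference formulation in plain PyTorch (a sketch of the semantics only, not the fused dropout/residual path):

```python
import torch

def rms_norm_reference(x, weight, eps=1e-5, zero_centered_weight=False):
    # Effective scale: the kernel does `w += 1.0` when zero_centered_weight is set.
    w = weight + 1.0 if zero_centered_weight else weight
    rstd = torch.rsqrt(x.float().pow(2).mean(dim=-1, keepdim=True) + eps)
    return (x.float() * rstd * w.float()).to(x.dtype)
```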
1128
 
1129
 
1130
  class LayerNormLinearFn(torch.autograd.Function):
1131
+
1132
  @staticmethod
1133
+ @custom_fwd
1134
  def forward(
1135
  ctx,
1136
  x,
 
1146
  ):
1147
  x_shape_og = x.shape
1148
  # reshape input data into 2D tensor
1149
+ x = maybe_contiguous_lastdim(x.reshape(-1, x.shape[-1]))
 
 
1150
  if residual is not None:
1151
  assert residual.shape == x_shape_og
1152
+ residual = maybe_contiguous_lastdim(residual.reshape(-1, residual.shape[-1]))
 
 
1153
  norm_weight = norm_weight.contiguous()
1154
+ norm_bias = maybe_contiguous(norm_bias)
 
1155
  residual_dtype = (
1156
  residual.dtype
1157
  if residual is not None
 
1163
  norm_bias,
1164
  eps,
1165
  residual,
1166
+ out_dtype=None if not torch.is_autocast_enabled() else torch.get_autocast_dtype("cuda"),
 
 
 
 
1167
  residual_dtype=residual_dtype,
1168
  is_rms_norm=is_rms_norm,
1169
  )
1170
  y = y.reshape(x_shape_og)
1171
+ dtype = torch.get_autocast_dtype("cuda") if torch.is_autocast_enabled() else y.dtype
 
 
1172
  linear_weight = linear_weight.to(dtype)
1173
  linear_bias = linear_bias.to(dtype) if linear_bias is not None else None
1174
  out = F.linear(y.to(linear_weight.dtype), linear_weight, linear_bias)
1175
  # We don't store y, will be recomputed in the backward pass to save memory
1176
+ ctx.save_for_backward(residual_out, norm_weight, norm_bias, linear_weight, mean, rstd)
 
 
1177
  ctx.x_shape_og = x_shape_og
1178
  ctx.eps = eps
1179
  ctx.is_rms_norm = is_rms_norm
 
1184
  return out if not prenorm else (out, residual_out.reshape(x_shape_og))
1185
 
1186
  @staticmethod
1187
+ @custom_bwd
1188
  def backward(ctx, dout, *args):
1189
  x, norm_weight, norm_bias, linear_weight, mean, rstd = ctx.saved_tensors
1190
  dout = dout.reshape(-1, dout.shape[-1])
1191
  dy = F.linear(dout, linear_weight.t())
1192
  dlinear_bias = None if ctx.linear_bias_is_none else dout.sum(0)
1193
+ dy = maybe_contiguous_lastdim(dy)
 
1194
  assert dy.shape == x.shape
1195
  if ctx.prenorm:
1196
  dresidual = args[0]
1197
+ dresidual = maybe_contiguous_lastdim(dresidual.reshape(-1, dresidual.shape[-1]))
 
 
1198
  assert dresidual.shape == x.shape
1199
  else:
1200
  dresidual = None
torch-ext/triton_layer_norm/utils/__init__.py ADDED
File without changes
torch-ext/triton_layer_norm/utils/library.py ADDED
@@ -0,0 +1,66 @@
1
+ # Adapted from https://github.com/pytorch/pytorch/blob/v2.7.0/torch/_library/triton.py
2
+ # The PyTorch implementation ignores the schema argument; this version is modified to actually pass the schema through.
3
+
4
+ from typing import Optional, Callable, Iterable, Union
5
+
6
+ from torch.library import custom_op, CustomOpDef
7
+ from torch._library.triton import set_wrap_triton_enabled
8
+
9
+
10
+ def triton_op(
11
+ name: str,
12
+ fn: Optional[Callable] = None,
13
+ /,
14
+ *,
15
+ mutates_args: Union[str, Iterable[str]],
16
+ schema: Optional[str] = None,
17
+ # If allow_decomposition=True, this matches torch.library.triton_op behavior. If set to False,
18
+ # then it behaves like torch.library.custom_op instead, which doesn't decompose the operator
19
+ # and so inductor can't trace inside.
20
+ allow_decomposition=True,
21
+ ) -> Callable:
22
+ def dec(fn: Callable[..., object]) -> CustomOpDef:
23
+ def backend_fn(*args, **kwargs): # type: ignore[no-untyped-def]
24
+ # Optimization: we're passing regular Tensors into the triton kernel, so
25
+ # no need to go through HOP dispatch
26
+ with set_wrap_triton_enabled(False):
27
+ return fn(*args, **kwargs)
28
+
29
+ result = custom_op(
30
+ name,
31
+ backend_fn,
32
+ mutates_args=mutates_args,
33
+ # This is the only difference with the PyTorch implementation
34
+ schema=schema,
35
+ )
36
+ from torch._subclasses.functional_tensor import FunctionalTensorMode
37
+
38
+ # We require that the user pass us a function that is make_fx traceable,
39
+ # so we can just register it as the Fake/meta kernel.
40
+ result.register_fake(fn)
41
+
42
+ if allow_decomposition:
43
+ # We decompose the operator when FunctionalTensorMode is active.
44
+ # The goal is to decompose the operator in AOTDispatcher.
45
+ # - With torch.compile, this means that the backend (usually Inductor)
46
+ # can see a call to the triton kernel(s) and so it can directly optimize
47
+ # them by inlining them into the lowering process.
48
+ def functional_decomp( # type: ignore[no-untyped-def]
49
+ mode, op, types, args, kwargs
50
+ ):
51
+ from torch.export._trace import custom_triton_ops_decomposition_disabled
52
+
53
+ if custom_triton_ops_decomposition_disabled():
54
+ return mode.__torch_dispatch__(op, types, args, kwargs)
55
+ else:
56
+ with mode:
57
+ return fn(*args, **kwargs)
58
+
59
+ result.register_torch_dispatch(FunctionalTensorMode, functional_decomp)
60
+
61
+ return result
62
+
63
+ if fn is None:
64
+ return dec
65
+ else:
66
+ return dec(fn)
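
In practice the wrapper is used like `torch.library.triton_op`, except that the explicit schema string is forwarded to `custom_op` and `allow_decomposition=False` keeps Inductor from tracing into the Triton launches. A hedged usage sketch with a made-up op name (the real registrations in this repo go through `add_op_namespace_prefix`):

```python
# Hypothetical registration through the wrapper above; names are illustrative.
import torch
from torch import Tensor

@triton_op(
    "mylib::scale",                       # made-up namespace::name
    mutates_args={},
    schema="(Tensor x, float alpha) -> Tensor",
    allow_decomposition=False,            # opaque to Inductor, like torch.library.custom_op
)
def scale(x: Tensor, alpha: float) -> Tensor:
    # A real kernel would launch Triton via torch.library.wrap_triton(...) here.
    return x * alpha

# After registration the op is callable as torch.ops.mylib.scale(x, 0.5)
# and torch.compile treats it as an opaque custom op rather than tracing its body.
```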
torch-ext/triton_layer_norm/utils/torch.py ADDED
@@ -0,0 +1,21 @@
1
+ import torch
2
+ from typing import Callable
3
+
4
+
5
+ def custom_amp_decorator(dec: Callable, cuda_amp_deprecated: bool):
6
+ def decorator(*args, **kwargs):
7
+ if cuda_amp_deprecated:
8
+ kwargs["device_type"] = "cuda"
9
+ return dec(*args, **kwargs)
10
+ return decorator
11
+
12
+
13
+ if hasattr(torch.amp, "custom_fwd"): # type: ignore[attr-defined]
14
+ deprecated = True
15
+ from torch.amp import custom_fwd, custom_bwd # type: ignore[attr-defined]
16
+ else:
17
+ deprecated = False
18
+ from torch.cuda.amp import custom_fwd, custom_bwd
19
+
20
+ custom_fwd = custom_amp_decorator(custom_fwd, deprecated)
21
+ custom_bwd = custom_amp_decorator(custom_bwd, deprecated)
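
A short sketch of how the shim is consumed: the decorators behave the same on old and new PyTorch, with `device_type="cuda"` filled in automatically where the `torch.amp` variants require it (illustrative autograd function, not from the diff):

```python
import torch

class _Square(torch.autograd.Function):
    @staticmethod
    @custom_fwd            # resolves to torch.amp.custom_fwd(..., device_type="cuda") on recent torch
    def forward(ctx, x):
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @custom_bwd
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out
```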