"""Triton layer normalization kernels

This kernel implements layer normalization using Triton. It comes from the
`flash-attention <https://github.com/Dao-AILab/flash-attention>`_ project.
"""

from typing import Optional

import torch

from . import layers
from .layer_norm import layer_norm_fn, layer_norm_linear_fn, rms_norm_fn


def layer_norm(
    x: torch.Tensor,
    weight: torch.Tensor,
    bias: torch.Tensor,
    residual: Optional[torch.Tensor] = None,
    x1: Optional[torch.Tensor] = None,
    weight1: Optional[torch.Tensor] = None,
    bias1: Optional[torch.Tensor] = None,
    eps: float = 1e-6,
    dropout_p: float = 0.0,
    rowscale: Optional[torch.Tensor] = None,
    prenorm: bool = False,
    residual_in_fp32: bool = False,
    is_rms_norm: bool = False,
    return_dropout_mask: bool = False,
    out: Optional[torch.Tensor] = None,
    residual_out: Optional[torch.Tensor] = None,
):
    """
    Apply layer normalization to the input tensor with Triton acceleration.

    Args:
        x (`torch.Tensor`):
            Input tensor to normalize.
        weight (`torch.Tensor`):
            Scale parameter for normalization.
        bias (`torch.Tensor`):
            Shift parameter for normalization.
        residual (`torch.Tensor`, *optional*):
            Optional residual tensor to add to the input before normalization.
        x1 (`torch.Tensor`, *optional*):
            Optional second input tensor to combine with `x`. When provided, the function
            first adds `x1` to `x` and then applies normalization.
        weight1 (`torch.Tensor`, *optional*):
            Scale parameter for the second normalization.
        bias1 (`torch.Tensor`, *optional*):
            Shift parameter for the second normalization.
        eps (`float`, *optional*, defaults to 1e-6):
            Small constant added for numerical stability in normalization.
        dropout_p (`float`, *optional*, defaults to 0.0):
            Dropout probability. If greater than 0, applies dropout to the input before
            normalization and residual addition.
        rowscale (`torch.Tensor`, *optional*):
            Optional scaling factor applied to each row of the input tensor.
            Not compatible with the use of `x1`.
        prenorm (`bool`, *optional*, defaults to False):
            If True, returns both the normalized output and the unnormalized input+residual.
        residual_in_fp32 (`bool`, *optional*, defaults to False):
            If True, performs the residual connection in FP32 precision.
        is_rms_norm (`bool`, *optional*, defaults to False):
            If True, uses RMS normalization instead of layer normalization.
        return_dropout_mask (`bool`, *optional*, defaults to False):
            If True, returns the dropout mask used for the computation.
        out (`torch.Tensor`, *optional*):
            Output tensor for the normalized result. If `None`, a new tensor is allocated.
        residual_out (`torch.Tensor`, *optional*):
            Output tensor for the residual result when using prenorm. If `None`, a new tensor
            is allocated when needed.

    Returns:
        `torch.Tensor` or tuple of `torch.Tensor`:
            - The normalized input.
            - The second normalization of the input if `weight1` is provided.
            - The residual tensor if `prenorm` is set.
            - The dropout mask if `return_dropout_mask` is set.
            - The dropout mask for `x1` if `x1` is provided and `return_dropout_mask` is set.
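
    Example:
        A minimal usage sketch (illustrative, not taken from the original kernel docs). It
        assumes a CUDA device and `float16` tensors, since the Triton kernel targets the GPU;
        shapes and values are arbitrary, and the `prenorm` return layout shown below is the
        expected one rather than a guarantee.

        >>> import torch
        >>> x = torch.randn(4, 64, device="cuda", dtype=torch.float16)
        >>> weight = torch.ones(64, device="cuda", dtype=torch.float16)
        >>> bias = torch.zeros(64, device="cuda", dtype=torch.float16)
        >>> out = layer_norm(x, weight, bias)
        >>> out.shape
        torch.Size([4, 64])

        >>> # With `prenorm=True` and a residual, the unnormalized `x + residual` is also returned.
        >>> residual = torch.randn_like(x)
        >>> out, res = layer_norm(x, weight, bias, residual=residual, prenorm=True)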
    """
    return layer_norm_fn(
        x,
        weight,
        bias,
        residual,
        x1,
        weight1,
        bias1,
        eps,
        dropout_p,
        rowscale,
        prenorm,
        residual_in_fp32,
        is_rms_norm,
        return_dropout_mask,
        out=out,
        residual_out=residual_out,
    )


__kernel_metadata__ = {
    "license": "bsd-3-clause",
}


__all__ = [
    "__kernel_metadata__",
    "layers",
    "layer_norm",
    "layer_norm_fn",
    "layer_norm_linear_fn",
    "rms_norm_fn",
]
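

if __name__ == "__main__":
    # Illustrative sanity check, not part of the original kernel sources. It assumes a CUDA
    # device (the Triton kernel targets the GPU) and compares the fused kernel against a
    # plain PyTorch reference for the simplest configuration: residual addition followed by
    # layer normalization, with no dropout, `x1`, or `weight1`.
    import torch.nn.functional as F

    hidden = 64
    x = torch.randn(4, hidden, device="cuda", dtype=torch.float16)
    residual = torch.randn_like(x)
    weight = torch.randn(hidden, device="cuda", dtype=torch.float16)
    bias = torch.randn(hidden, device="cuda", dtype=torch.float16)

    out = layer_norm(x, weight, bias, residual=residual)
    ref = F.layer_norm((x + residual).float(), (hidden,), weight.float(), bias.float(), eps=1e-6)
    print("max abs diff vs. PyTorch reference:", (out.float() - ref).abs().max().item())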