import math
from enum import Enum
from typing import Optional

import triton
import triton.language as tl

_sqrt2pi = math.sqrt(2.0 / math.pi)
_sqrt1_2 = math.sqrt(1.0 / 2)
_gaussian_pdf_normalization = 1.0 / math.sqrt(2 * math.pi)


class Activation(str, Enum):
    SquaredReLU = "squared_relu"
    GeLU = "gelu"
    GeLUApprox = "gelu_approx"
    LeakyReLU = "leaky_relu"
    ReLU = "relu"
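
# NOTE (illustrative, not part of the original module): because Activation
# derives from both str and Enum, plain config strings map back to members:
#   Activation("gelu") is Activation.GeLU   -> True
#   Activation.GeLU == "gelu"               -> True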

def get_triton_activation_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu,
            Activation.LeakyReLU: leaky_relu,
            Activation.GeLU: gelu,
            Activation.GeLUApprox: gelu_approx,
            Activation.SquaredReLU: squared_relu,
        }[activation]
        if activation
        else None
    )


def get_triton_activation_bwd_kernel(activation: Optional[Activation]):
    return (
        {
            Activation.ReLU: relu_grad,
            Activation.LeakyReLU: leaky_relu_grad,
            Activation.GeLU: gelu_grad,
            Activation.GeLUApprox: gelu_approx_grad,
            Activation.SquaredReLU: squared_relu_grad,
        }[activation]
        if activation
        else None
    )
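
# Example (illustrative, not from the original module): the two helpers above
# return matching forward/backward @triton.jit functions for a given activation.
# The returned objects are meant to be called from inside another Triton kernel,
# not from host-side Python code.
#
#   act_fwd = get_triton_activation_kernel(Activation.GeLU)      # -> gelu
#   act_bwd = get_triton_activation_bwd_kernel(Activation.GeLU)  # -> gelu_grad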


@triton.jit
def tanh(x):
    # tanh expressed through the sigmoid: tanh(x) = 2 * sigmoid(2x) - 1
    return 2 * tl.sigmoid(2 * x) - 1


@triton.jit
def cosh(x):
    # cosh(x) = (exp(x) + exp(-x)) / 2
    exp_x = tl.exp(x)
    return (exp_x + 1.0 / exp_x) * 0.5


@triton.jit
def relu(x):
    """
    ReLU_ activation function

    .. _ReLU: https://pytorch.org/docs/stable/generated/torch.nn.ReLU.html
    """
    zero = 0.0
    return tl.where(x >= 0, x, zero.to(x.dtype))


@triton.jit
def relu_grad(x):
    # ReLU derivative: 1 where x >= 0, 0 elsewhere
    zero = 0.0
    one = 1.0
    return tl.where(x >= 0, one.to(x.dtype), zero.to(x.dtype))


@triton.jit
def squared_relu(x):
    """
    Squared ReLU activation, as proposed in the Primer_ paper.

    .. _Primer: https://arxiv.org/abs/2109.08668
    """
    x_ = relu(x)
    return (x_ * x_).to(x.dtype)


@triton.jit
def squared_relu_grad(x):
    # d/dx relu(x)^2 = 2x for x >= 0, 0 elsewhere
    return tl.where(x >= 0, 2.0 * x, 0.0)


@triton.jit
def leaky_relu(x):
    """
    LeakyReLU_ activation

    .. _LeakyReLU: https://pytorch.org/docs/stable/generated/torch.nn.LeakyReLU.html
    """
    scale = 0.01 + 0.0
    scale = scale.to(x.dtype)
    return tl.where(x >= 0, x, scale * x)


@triton.jit
def leaky_relu_grad(x):
    # LeakyReLU derivative: 1 for x >= 0, the negative slope (0.01) elsewhere
    min_grad = 0.01
    max_grad = 1

    min_grad = min_grad.to(x.dtype)
    max_grad = max_grad.to(x.dtype)

    return tl.where(x >= 0, max_grad, min_grad)


@triton.jit
def gelu(x):
    """Gaussian Error Linear Unit (GELU)"""
    return x * 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))


@triton.jit
def gelu_grad(x):
    # d/dx [x * Phi(x)] = Phi(x) + x * phi(x), with Phi/phi the normal CDF/PDF
    cdf = 0.5 * (1.0 + tl.libdevice.erf(x * _sqrt1_2))
    pdf = tl.exp(-0.5 * x * x) * _gaussian_pdf_normalization
    return cdf + x * pdf


@triton.jit
def gelu_approx(x):
    """
    GeLU_ activation - Gaussian error linear unit, with tanh approximation

    .. _GeLU: https://arxiv.org/pdf/1606.08415.pdf
    """
    return 0.5 * x * (1.0 + tanh(_sqrt2pi * x * (1.0 + 0.044715 * x * x)))


@triton.jit
def gelu_approx_grad(x):
    # Gradient of the tanh-approximated GELU above (0.79788456 ~= sqrt(2/pi))
    tanh_out = tanh(0.79788456 * x * (1 + 0.044715 * x * x))
    return 0.5 * x * ((1 - tanh_out * tanh_out) * (0.79788456 + 0.1070322243 * x * x)) + 0.5 * (
        1 + tanh_out
    )
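

# Usage sketch (an assumption for illustration, not part of the original module):
# a minimal elementwise kernel showing how one of the @triton.jit activations
# above can be fused into another Triton kernel. It assumes a Triton version
# that accepts @triton.jit functions as constexpr kernel arguments; the kernel
# name `_activation_fwd` and its parameters are made up for this example.
@triton.jit
def _activation_fwd(X, Y, N, ACTIVATION: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offsets = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offsets < N
    x = tl.load(X + offsets, mask=mask)
    # e.g. ACTIVATION = get_triton_activation_kernel(Activation.GeLU)
    y = ACTIVATION(x)
    tl.store(Y + offsets, y, mask=mask)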