File size: 6,434 Bytes
9c6594c |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 |
from __future__ import annotations
import copy
from typing import TypeVar
import torch
# Public API of this module.
__all__ = [
    # Convolution + BatchNorm folding.
    "fuse_conv_bn_eval",
    "fuse_conv_bn_weights",
    # Linear + BatchNorm folding.
    "fuse_linear_bn_eval",
    "fuse_linear_bn_weights",
]

# Type variables that let the ``*_eval`` helpers declare they return the same
# module type they were given (the result is a deep copy of the input module).
ConvT = TypeVar("ConvT", bound="torch.nn.modules.conv._ConvNd")
LinearT = TypeVar("LinearT", bound="torch.nn.Linear")
def fuse_conv_bn_eval(
    conv: ConvT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
    transpose: bool = False,
) -> ConvT:
    r"""Fuse a convolutional module and a BatchNorm module into a single, new convolutional module.

    ``conv`` is deep-copied and the copy receives the fused weight and bias;
    neither input module is modified.

    Args:
        conv (torch.nn.modules.conv._ConvNd): A convolutional module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.
        transpose (bool, optional): Set to True when ``conv`` is a transposed
            convolution. Its weight stores output channels along dim 1, so the
            BatchNorm scale is applied there instead of along dim 0. No
            transposition of the weight is performed. Defaults to False.

    Returns:
        torch.nn.modules.conv._ConvNd: The fused convolutional module.

    Raises:
        AssertionError: If ``conv`` or ``bn`` is in training mode, or if ``bn``
            has no running mean/variance buffers.

    .. note::
        Both ``conv`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
    """
    # Explicit raises (instead of ``assert``) keep these checks active under
    # ``python -O`` while preserving the AssertionError type callers may rely on.
    if conv.training or bn.training:
        raise AssertionError("Fusion only for eval!")
    if bn.running_mean is None or bn.running_var is None:
        raise AssertionError("BatchNorm must have running statistics to be fused")

    fused_conv = copy.deepcopy(conv)
    fused_conv.weight, fused_conv.bias = fuse_conv_bn_weights(
        fused_conv.weight,
        fused_conv.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
        transpose,
    )
    return fused_conv
def fuse_conv_bn_weights(
    conv_w: torch.Tensor,
    conv_b: torch.Tensor | None,
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor | None,
    bn_b: torch.Tensor | None,
    transpose: bool = False,
) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse convolutional module parameters and BatchNorm module parameters into new convolutional module parameters.

    The fused parameters compute ``bn(conv(x))`` with a single convolution:
    ``w' = w * gamma / sqrt(var + eps)`` (per output channel) and
    ``b' = (b - mean) / sqrt(var + eps) * gamma + beta``.

    Args:
        conv_w (torch.Tensor): Convolutional weight.
        conv_b (Optional[torch.Tensor]): Convolutional bias. ``None`` is treated as zeros.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (Optional[torch.Tensor]): BatchNorm weight. ``None`` is treated as ones.
        bn_b (Optional[torch.Tensor]): BatchNorm bias. ``None`` is treated as zeros.
        transpose (bool, optional): Set to True when ``conv_w`` is a
            transposed-convolution weight, laid out as
            ``(C_in, C_out // groups, *kernel)``. Grouped transposed weights
            are supported. Defaults to False.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused convolutional weight and bias.
        They have the dtypes of ``conv_w`` and ``conv_b`` (or of ``conv_w``
        when there was no bias).
    """
    conv_weight_dtype = conv_w.dtype
    conv_bias_dtype = conv_b.dtype if conv_b is not None else conv_weight_dtype
    # Substitute identity values for absent bias / affine parameters so one
    # formula handles every combination.
    if conv_b is None:
        conv_b = torch.zeros_like(bn_rm)
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)

    bn_var_rsqrt = torch.rsqrt(bn_rv + bn_eps)
    scale = bn_w * bn_var_rsqrt  # one multiplier per output channel
    kernel_ones = [1] * (len(conv_w.shape) - 2)

    if not transpose:
        # Regular conv weight: output channels live on dim 0.
        fused_conv_w = conv_w * scale.reshape([-1, 1] + kernel_ones)
    else:
        out_per_group = conv_w.shape[1]
        num_out = scale.numel()
        if num_out != out_per_group and num_out % out_per_group == 0:
            # Grouped transposed conv: dim 1 only holds C_out // groups channels.
            # Input-channel group g feeds output channels
            # [g * out_per_group, (g + 1) * out_per_group), so view the weight
            # as (groups, C_in // groups, out_per_group, *kernel) to pair each
            # weight slice with its own slice of the scale.
            groups = num_out // out_per_group
            grouped_w = conv_w.reshape(groups, -1, out_per_group, *conv_w.shape[2:])
            grouped_scale = scale.reshape([groups, 1, out_per_group] + kernel_ones)
            fused_conv_w = (grouped_w * grouped_scale).reshape(conv_w.shape)
        else:
            fused_conv_w = conv_w * scale.reshape([1, -1] + kernel_ones)
    # Compute in the promoted dtype, then round once back to the conv's dtype.
    fused_conv_w = fused_conv_w.to(dtype=conv_weight_dtype)
    fused_conv_b = ((conv_b - bn_rm) * bn_var_rsqrt * bn_w + bn_b).to(
        dtype=conv_bias_dtype
    )
    return (
        torch.nn.Parameter(fused_conv_w, conv_w.requires_grad),
        torch.nn.Parameter(fused_conv_b, conv_b.requires_grad),
    )
def fuse_linear_bn_eval(
    linear: LinearT,
    bn: torch.nn.modules.batchnorm._BatchNorm,
) -> LinearT:
    r"""Fuse a linear module and a BatchNorm module into a single, new linear module.

    ``linear`` is deep-copied and the copy receives the fused weight and bias;
    neither input module is modified.

    Args:
        linear (torch.nn.Linear): A Linear module.
        bn (torch.nn.modules.batchnorm._BatchNorm): A BatchNorm module.

    Returns:
        torch.nn.Linear: The fused linear module.

    Raises:
        AssertionError: If either module is in training mode, if
            ``bn.num_features`` is neither ``linear.out_features`` nor 1, or if
            ``bn`` has no running mean/variance buffers.

    .. note::
        Both ``linear`` and ``bn`` must be in eval mode, and ``bn`` must have its running buffers computed.
    """
    # Explicit raises (instead of ``assert``) keep validation active under
    # ``python -O`` while preserving the AssertionError type callers may rely on.
    if linear.training or bn.training:
        raise AssertionError("Fusion only for eval!")

    # Linear-BN must be fused while preserving the shapes of the linear
    # weight/bias. BN operates over the channel dim, (N, C_in, H, W), while
    # Linear operates over the last dim, (*, H_in). For the BN statistics to
    # broadcast against the Linear's output features, the number of BN features
    # must either equal ``linear.out_features`` or be 1 (a single scale shared
    # by every output feature). Any other combination cannot be folded and is
    # rejected.
    if not (linear.out_features == bn.num_features or bn.num_features == 1):
        raise AssertionError(
            "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"
        )
    if bn.running_mean is None or bn.running_var is None:
        raise AssertionError("BatchNorm must have running statistics to be fused")

    fused_linear = copy.deepcopy(linear)
    fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
        fused_linear.weight,
        fused_linear.bias,
        bn.running_mean,
        bn.running_var,
        bn.eps,
        bn.weight,
        bn.bias,
    )
    return fused_linear
def fuse_linear_bn_weights(
    linear_w: torch.Tensor,
    linear_b: torch.Tensor | None,
    bn_rm: torch.Tensor,
    bn_rv: torch.Tensor,
    bn_eps: float,
    bn_w: torch.Tensor | None,
    bn_b: torch.Tensor | None,
) -> tuple[torch.nn.Parameter, torch.nn.Parameter]:
    r"""Fuse linear module parameters and BatchNorm module parameters into new linear module parameters.

    The fused parameters compute ``bn(linear(x))`` with a single linear layer:
    ``w' = w * gamma / sqrt(var + eps)`` (per output row) and
    ``b' = (b - mean) / sqrt(var + eps) * gamma + beta``.

    Args:
        linear_w (torch.Tensor): Linear weight.
        linear_b (Optional[torch.Tensor]): Linear bias. ``None`` is treated as zeros.
        bn_rm (torch.Tensor): BatchNorm running mean.
        bn_rv (torch.Tensor): BatchNorm running variance.
        bn_eps (float): BatchNorm epsilon.
        bn_w (Optional[torch.Tensor]): BatchNorm weight. ``None`` (non-affine BN) is treated as ones.
        bn_b (Optional[torch.Tensor]): BatchNorm bias. ``None`` (non-affine BN) is treated as zeros.

    Returns:
        Tuple[torch.nn.Parameter, torch.nn.Parameter]: Fused linear weight and bias.
        They have the shapes of the linear weight and of an ``(out_features,)``
        bias, in the dtypes of ``linear_w`` and ``linear_b`` (or of
        ``linear_w`` when there was no bias).
    """
    linear_weight_dtype = linear_w.dtype
    linear_bias_dtype = linear_b.dtype if linear_b is not None else linear_weight_dtype
    if linear_b is None:
        # Size the zero bias by out_features (dim 0 of the weight), not by the
        # BN statistics. A single-feature BN has length-1 stats, but the fused
        # bias must still match the Linear's (out_features,) bias shape.
        linear_b = bn_rm.new_zeros(linear_w.shape[0])
    # Non-affine BatchNorm has no weight/bias; use identity values instead.
    if bn_w is None:
        bn_w = torch.ones_like(bn_rm)
    if bn_b is None:
        bn_b = torch.zeros_like(bn_rm)

    bn_scale = bn_w * torch.rsqrt(bn_rv + bn_eps)

    # Multiply in the promoted dtype and round once to the weight's dtype,
    # matching how the convolution variant computes its fused weight.
    fused_w = (linear_w * bn_scale.unsqueeze(-1)).to(dtype=linear_weight_dtype)
    fused_b = ((linear_b - bn_rm) * bn_scale + bn_b).to(dtype=linear_bias_dtype)
    return (
        torch.nn.Parameter(fused_w, linear_w.requires_grad),
        torch.nn.Parameter(fused_b, linear_b.requires_grad),
    )
|