|
|
|
"""This file contains utilities for initializing neural network parameters.""" |
|
import math |
|
import warnings |
|
from typing import Optional as _Optional |
|
|
|
import torch |
|
from torch import Tensor |
|
|
|
|
|
|
|
|
|
|
|
|
|
def _no_grad_uniform_(tensor, a, b, generator=None): |
|
with torch.no_grad(): |
|
return tensor.uniform_(a, b, generator=generator) |
|
|
|
|
|
def _no_grad_normal_(tensor, mean, std, generator=None): |
|
with torch.no_grad(): |
|
return tensor.normal_(mean, std, generator=generator) |
|
|
|
|
|
def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None): |
|
|
|
def norm_cdf(x): |
|
|
|
return (1.0 + math.erf(x / math.sqrt(2.0))) / 2.0 |
|
|
|
if (mean < a - 2 * std) or (mean > b + 2 * std): |
|
warnings.warn( |
|
"mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " |
|
"The distribution of values may be incorrect.", |
|
stacklevel=2, |
|
) |
|
|
|
with torch.no_grad(): |
|
|
|
|
|
|
|
l = norm_cdf((a - mean) / std) |
|
u = norm_cdf((b - mean) / std) |
|
|
|
|
|
|
|
tensor.uniform_(2 * l - 1, 2 * u - 1, generator=generator) |
|
|
|
|
|
|
|
tensor.erfinv_() |
|
|
|
|
|
tensor.mul_(std * math.sqrt(2.0)) |
|
tensor.add_(mean) |
|
|
|
|
|
tensor.clamp_(min=a, max=b) |
|
return tensor |
|
|
|
|
|
def _no_grad_fill_(tensor, val): |
|
with torch.no_grad(): |
|
return tensor.fill_(val) |
|
|
|
|
|
def _no_grad_zero_(tensor): |
|
with torch.no_grad(): |
|
return tensor.zero_() |
|
|
|
|
|
def calculate_gain(nonlinearity, param=None):
    r"""Return the recommended gain value for the given nonlinearity function.

    The values are as follows:

    ================= ====================================================
    nonlinearity      gain
    ================= ====================================================
    Linear / Identity :math:`1`
    Conv{1,2,3}D      :math:`1`
    Sigmoid           :math:`1`
    Tanh              :math:`\frac{5}{3}`
    ReLU              :math:`\sqrt{2}`
    Leaky Relu        :math:`\sqrt{\frac{2}{1 + \text{negative\_slope}^2}}`
    SELU              :math:`\frac{3}{4}`
    ================= ====================================================

    Transposed convolutions (``conv_transpose{1,2,3}d``) also get a gain of 1.

    .. warning::
        In order to implement `Self-Normalizing Neural Networks`_ ,
        you should use ``nonlinearity='linear'`` instead of ``nonlinearity='selu'``.
        This gives the initial weights a variance of ``1 / N``,
        which is necessary to induce a stable fixed point in the forward pass.
        In contrast, the default gain for ``SELU`` sacrifices the normalization
        effect for more stable gradient flow in rectangular layers.

    Args:
        nonlinearity: the non-linear function (`nn.functional` name)
        param: optional parameter for the non-linear function. Only consulted
            for ``'leaky_relu'``, where it is the negative slope (0.01 when
            omitted) and must be an ``int`` (but not ``bool``) or ``float``.

    Raises:
        ValueError: if ``nonlinearity`` is unsupported, or ``param`` is not a
            valid number for ``'leaky_relu'``.

    Examples:
        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2

    .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
    """
    unit_gain_fns = (
        "linear",
        "conv1d",
        "conv2d",
        "conv3d",
        "conv_transpose1d",
        "conv_transpose2d",
        "conv_transpose3d",
        "sigmoid",
    )
    if nonlinearity in unit_gain_fns:
        return 1
    if nonlinearity == "tanh":
        return 5.0 / 3
    if nonlinearity == "relu":
        return math.sqrt(2.0)
    if nonlinearity == "leaky_relu":
        if param is None:
            slope = 0.01
        # bool is a subclass of int, so it has to be rejected explicitly.
        elif isinstance(param, float) or (
            isinstance(param, int) and not isinstance(param, bool)
        ):
            slope = param
        else:
            raise ValueError(f"negative_slope {param} not a valid number")
        return math.sqrt(2.0 / (1 + slope**2))
    if nonlinearity == "selu":
        return 3.0 / 4
    raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
|
|
|
|
|
def uniform_(
    tensor: Tensor,
    a: float = 0.0,
    b: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the uniform distribution.

    :math:`\mathcal{U}(a, b)`.

    Tensor subclasses that implement ``__torch_function__`` are dispatched
    through ``torch.overrides.handle_torch_function`` instead.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the lower bound of the uniform distribution
        b: the upper bound of the uniform distribution
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` itself, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.uniform_(w)
    """
    overrides = torch.overrides
    if not overrides.has_torch_function_variadic(tensor):
        with torch.no_grad():
            return tensor.uniform_(a, b, generator=generator)
    return overrides.handle_torch_function(
        uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
    )
|
|
|
|
|
def normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from the normal distribution.

    :math:`\mathcal{N}(\text{mean}, \text{std}^2)`.

    Tensor subclasses that implement ``__torch_function__`` are dispatched
    through ``torch.overrides.handle_torch_function`` instead.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` itself, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.normal_(w)
    """
    overrides = torch.overrides
    if not overrides.has_torch_function_variadic(tensor):
        with torch.no_grad():
            return tensor.normal_(mean, std, generator=generator)
    return overrides.handle_torch_function(
        normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
    )
|
|
|
|
|
def trunc_normal_(
    tensor: Tensor,
    mean: float = 0.0,
    std: float = 1.0,
    a: float = -2.0,
    b: float = 2.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input Tensor with values drawn from a truncated normal distribution.

    The values are effectively drawn from the
    normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)`
    with values outside :math:`[a, b]` redrawn until they are within
    the bounds. The method used for generating the random values works
    best when :math:`a \leq \text{mean} \leq b`.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        mean: the mean of the normal distribution
        std: the standard deviation of the normal distribution
        a: the minimum cutoff value
        b: the maximum cutoff value
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` itself, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.trunc_normal_(w)
    """
    return _no_grad_trunc_normal_(
        tensor, mean=mean, std=std, a=a, b=b, generator=generator
    )
|
|
|
|
|
def constant_(tensor: Tensor, val: float) -> Tensor:
    r"""Fill the input Tensor with the value :math:`\text{val}`.

    Tensor subclasses that implement ``__torch_function__`` are dispatched
    through ``torch.overrides.handle_torch_function`` instead.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        val: the value to fill the tensor with

    Returns:
        ``tensor`` itself, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.constant_(w, 0.3)
    """
    overrides = torch.overrides
    if not overrides.has_torch_function_variadic(tensor):
        with torch.no_grad():
            return tensor.fill_(val)
    return overrides.handle_torch_function(
        constant_, (tensor,), tensor=tensor, val=val
    )
|
|
|
|
|
def ones_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `1`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor`` itself, filled in place without autograd tracking.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.ones_(w)
    """
    with torch.no_grad():
        filled = tensor.fill_(1.0)
    return filled
|
|
|
|
|
def zeros_(tensor: Tensor) -> Tensor:
    r"""Fill the input Tensor with the scalar value `0`.

    Args:
        tensor: an n-dimensional `torch.Tensor`

    Returns:
        ``tensor`` itself, zeroed in place without autograd tracking.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.zeros_(w)
    """
    with torch.no_grad():
        zeroed = tensor.zero_()
    return zeroed
|
|
|
|
|
def eye_(tensor):
    r"""Fill the 2-dimensional input `Tensor` with the identity matrix.

    Preserves the identity of the inputs in `Linear` layers, where as
    many inputs are preserved as possible. Non-square tensors get ones on
    the main diagonal and zeros elsewhere.

    Args:
        tensor: a 2-dimensional `torch.Tensor`

    Returns:
        ``tensor`` itself, filled in place.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.eye_(w)
    """
    if tensor.ndimension() == 2:
        n_rows, n_cols = tensor.shape
        with torch.no_grad():
            torch.eye(n_rows, n_cols, out=tensor, requires_grad=tensor.requires_grad)
        return tensor
    raise ValueError("Only tensors with 2 dimensions are supported")
|
|
|
|
|
def dirac_(tensor, groups=1):
    r"""Fill the {3, 4, 5}-dimensional input `Tensor` with the Dirac delta function.

    Preserves the identity of the inputs in `Convolutional`
    layers, where as many input channels are preserved as possible. In case
    of groups>1, each group of channels preserves identity.

    Every element is first zeroed. Then, within each group, output channel
    ``d`` gets a single ``1`` at input channel ``d`` in the spatial centre of
    the kernel, for ``d < min(out_channels_per_group, in_channels)``.

    Args:
        tensor: a {3, 4, 5}-dimensional `torch.Tensor`
        groups (int, optional): number of groups in the conv layer (default: 1)

    Returns:
        ``tensor`` itself, filled in place.

    Raises:
        ValueError: if ``tensor`` does not have 3, 4 or 5 dimensions, or if
            its size along dim 0 is not divisible by ``groups``.

    Examples:
        >>> w = torch.empty(3, 16, 5, 5)
        >>> nn.init.dirac_(w)
        >>> w = torch.empty(3, 24, 5, 5)
        >>> nn.init.dirac_(w, 3)
    """
    ndim = tensor.ndimension()
    if ndim not in (3, 4, 5):
        raise ValueError("Only tensors with 3, 4, or 5 dimensions are supported")

    shape = tensor.size()
    if shape[0] % groups:
        raise ValueError("dim 0 must be divisible by groups")

    per_group = shape[0] // groups
    n_identity = min(per_group, shape[1])
    # One centre index per spatial (trailing) dimension.
    centre = tuple(extent // 2 for extent in shape[2:])

    with torch.no_grad():
        tensor.zero_()
        for grp in range(groups):
            base = grp * per_group
            for ch in range(n_identity):
                tensor[(base + ch, ch) + centre] = 1
    return tensor
|
|
|
|
|
def _calculate_fan_in_and_fan_out(tensor): |
|
dimensions = tensor.dim() |
|
if dimensions < 2: |
|
raise ValueError( |
|
"Fan in and fan out can not be computed for tensor with fewer than 2 dimensions" |
|
) |
|
|
|
num_input_fmaps = tensor.size(1) |
|
num_output_fmaps = tensor.size(0) |
|
receptive_field_size = 1 |
|
if tensor.dim() > 2: |
|
|
|
|
|
for s in tensor.shape[2:]: |
|
receptive_field_size *= s |
|
fan_in = num_input_fmaps * receptive_field_size |
|
fan_out = num_output_fmaps * receptive_field_size |
|
|
|
return fan_in, fan_out |
|
|
|
|
|
def xavier_uniform_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier uniform distribution.

    The method is described in `Understanding the difficulty of training
    deep feedforward neural networks` - Glorot, X. & Bengio, Y. (2010).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-a, a)` where

    .. math::
        a = \text{gain} \times \sqrt{\frac{6}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` itself, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.xavier_uniform_(w.T, ...)``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    # U(-bound, bound) has std bound / sqrt(3), so scale the Glorot target std.
    target_std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
    bound = math.sqrt(3.0) * target_std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound, generator=generator)
|
|
|
|
|
def xavier_normal_(
    tensor: Tensor,
    gain: float = 1.0,
    generator: _Optional[torch.Generator] = None,
) -> Tensor:
    r"""Fill the input `Tensor` with values using a Xavier normal distribution.

    The method is described in `Understanding the difficulty of training deep feedforward
    neural networks` - Glorot, X. & Bengio, Y. (2010). The resulting tensor
    will have values sampled from :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \text{gain} \times \sqrt{\frac{2}{\text{fan\_in} + \text{fan\_out}}}

    Also known as Glorot initialization.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        gain: an optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` itself, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.xavier_normal_(w)

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.xavier_normal_(w.T, ...)``.
    """
    fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
    fan_total = float(fan_in + fan_out)
    target_std = gain * math.sqrt(2.0 / fan_total)
    with torch.no_grad():
        return tensor.normal_(0.0, target_std, generator=generator)
|
|
|
|
|
def _calculate_correct_fan(tensor, mode):
    """Return ``tensor``'s fan_in or fan_out, selected by ``mode`` (case-insensitive).

    Raises:
        ValueError: if ``mode`` is neither ``"fan_in"`` nor ``"fan_out"``.
    """
    mode = mode.lower()
    valid_modes = ["fan_in", "fan_out"]
    if mode not in valid_modes:
        raise ValueError(f"Mode {mode} not supported, please use one of {valid_modes}")

    fans = dict(zip(valid_modes, _calculate_fan_in_and_fan_out(tensor)))
    return fans[mode]
|
|
|
|
|
def kaiming_uniform_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{U}(-\text{bound}, \text{bound})` where

    .. math::
        \text{bound} = \text{gain} \times \sqrt{\frac{3}{\text{fan\_mode}}}

    Also known as He initialization.

    A tensor with a zero-sized dimension is returned unchanged and a warning is
    emitted, attributed to the caller's line.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` itself, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.kaiming_uniform_(w.T, ...)``.
    """
    if torch.overrides.has_torch_function_variadic(tensor):
        return torch.overrides.handle_torch_function(
            kaiming_uniform_,
            (tensor,),
            tensor=tensor,
            a=a,
            mode=mode,
            nonlinearity=nonlinearity,
            generator=generator,
        )

    if 0 in tensor.shape:
        # stacklevel=2 attributes the warning to the caller, not to this module.
        warnings.warn("Initializing zero-element tensors is a no-op", stacklevel=2)
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    # U(-bound, bound) has std bound / sqrt(3).
    bound = math.sqrt(3.0) * std
    with torch.no_grad():
        return tensor.uniform_(-bound, bound, generator=generator)
|
|
|
|
|
def kaiming_normal_(
    tensor: Tensor,
    a: float = 0,
    mode: str = "fan_in",
    nonlinearity: str = "leaky_relu",
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with values using a Kaiming normal distribution.

    The method is described in `Delving deep into rectifiers: Surpassing
    human-level performance on ImageNet classification` - He, K. et al. (2015).
    The resulting tensor will have values sampled from
    :math:`\mathcal{N}(0, \text{std}^2)` where

    .. math::
        \text{std} = \frac{\text{gain}}{\sqrt{\text{fan\_mode}}}

    Also known as He initialization.

    A tensor with a zero-sized dimension is returned unchanged and a warning is
    emitted, attributed to the caller's line.

    Args:
        tensor: an n-dimensional `torch.Tensor`
        a: the negative slope of the rectifier used after this layer (only
            used with ``'leaky_relu'``)
        mode: either ``'fan_in'`` (default) or ``'fan_out'``. Choosing ``'fan_in'``
            preserves the magnitude of the variance of the weights in the
            forward pass. Choosing ``'fan_out'`` preserves the magnitudes in the
            backwards pass.
        nonlinearity: the non-linear function (`nn.functional` name),
            recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` itself, filled in place.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')

    Note:
        Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
        that the weight matrix is used in a transposed manner,
        (i.e., ``x @ w.T`` in ``Linear`` layers, where ``w.shape = [fan_out, fan_in]``).
        This is important for correct initialization.
        If you plan to use ``x @ w``, where ``w.shape = [fan_in, fan_out]``,
        pass in a transposed weight matrix, i.e. ``nn.init.kaiming_normal_(w.T, ...)``.
    """
    if 0 in tensor.shape:
        # stacklevel=2 attributes the warning to the caller, not to this module.
        warnings.warn("Initializing zero-element tensors is a no-op", stacklevel=2)
        return tensor
    fan = _calculate_correct_fan(tensor, mode)
    gain = calculate_gain(nonlinearity, a)
    std = gain / math.sqrt(fan)
    with torch.no_grad():
        return tensor.normal_(0, std, generator=generator)
|
|
|
|
|
def orthogonal_(
    tensor,
    gain=1,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the input `Tensor` with a (semi) orthogonal matrix.

    Described in `Exact solutions to the nonlinear dynamics of learning in deep
    linear neural networks` - Saxe, A. et al. (2013). The input tensor must have
    at least 2 dimensions, and for tensors with more than 2 dimensions the
    trailing dimensions are flattened.

    Empty tensors are returned unchanged.

    Args:
        tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
        gain: optional scaling factor
        generator: the torch Generator to sample from (default: None)

    Returns:
        ``tensor`` itself, filled in place.

    Raises:
        ValueError: if ``tensor`` has fewer than 2 dimensions.

    Examples:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
        >>> w = torch.empty(3, 5)
        >>> nn.init.orthogonal_(w)
    """
    if tensor.ndimension() < 2:
        raise ValueError("Only tensors with 2 or more dimensions are supported")
    if tensor.numel() == 0:
        return tensor  # nothing to fill

    n_rows = tensor.size(0)
    n_cols = tensor.numel() // n_rows
    is_wide = n_rows < n_cols

    gaussian = tensor.new_empty((n_rows, n_cols)).normal_(0, 1, generator=generator)
    # QR needs a tall matrix to produce orthonormal columns; transpose wide ones.
    if is_wide:
        gaussian.t_()

    q, r = torch.linalg.qr(gaussian)
    # Multiply Q's columns by the signs of R's diagonal so the result does not
    # depend on the QR sign convention (cf. Mezzadri, arXiv math-ph/0609050).
    q *= torch.diag(r, 0).sign()

    if is_wide:
        q.t_()

    with torch.no_grad():
        tensor.view_as(q).copy_(q)
        tensor.mul_(gain)
    return tensor
|
|
|
|
|
def sparse_(
    tensor,
    sparsity,
    std=0.01,
    generator: _Optional[torch.Generator] = None,
):
    r"""Fill the 2D input `Tensor` as a sparse matrix.

    The non-zero elements will be drawn from the normal distribution
    :math:`\mathcal{N}(0, \text{std}^2)` (``std=0.01`` by default), as described in
    `Deep learning via Hessian-free optimization` - Martens, J. (2010).
    In every column, ``ceil(sparsity * rows)`` randomly chosen entries are
    then set to zero.

    Args:
        tensor: a 2-dimensional `torch.Tensor`
        sparsity: The fraction of elements in each column to be set to zero
        std: the standard deviation of the normal distribution used to generate
            the non-zero values
        generator: the torch Generator to sample from (default: None). It is
            used both for the normal values and for choosing which entries are
            zeroed, so a seeded generator makes the whole result reproducible.

    Returns:
        ``tensor`` itself, filled in place.

    Raises:
        ValueError: if ``tensor`` is not 2-dimensional.

    Examples:
        >>> w = torch.empty(3, 5)
        >>> nn.init.sparse_(w, sparsity=0.1)
    """
    if tensor.ndimension() != 2:
        raise ValueError("Only tensors with 2 dimensions are supported")

    rows, cols = tensor.shape
    num_zeros = int(math.ceil(sparsity * rows))

    # Previously randperm always used the global RNG, so the zero pattern was
    # not reproducible from ``generator``. Draw the permutation from the
    # caller's generator, on that generator's device (randperm requires them
    # to match). With no generator, fall back to the global CPU RNG as before.
    perm_kwargs = (
        {} if generator is None else {"generator": generator, "device": generator.device}
    )

    with torch.no_grad():
        tensor.normal_(0, std, generator=generator)
        for col_idx in range(cols):
            row_indices = torch.randperm(rows, **perm_kwargs)
            zero_indices = row_indices[:num_zeros]
            tensor[zero_indices, col_idx] = 0
    return tensor
|
|
|
|
|
|
|
def _make_deprecate(meth): |
|
new_name = meth.__name__ |
|
old_name = new_name[:-1] |
|
|
|
def deprecated_init(*args, **kwargs): |
|
warnings.warn( |
|
f"`nn.init.{old_name}` is now deprecated in favor of `nn.init.{new_name}`.", |
|
FutureWarning, |
|
stacklevel=2, |
|
) |
|
return meth(*args, **kwargs) |
|
|
|
deprecated_init.__doc__ = rf""" |
|
{old_name}(...) |
|
|
|
.. warning:: |
|
This method is now deprecated in favor of :func:`torch.nn.init.{new_name}`. |
|
|
|
See :func:`~torch.nn.init.{new_name}` for details.""" |
|
deprecated_init.__name__ = old_name |
|
return deprecated_init |
|
|
|
|
|
# Deprecated aliases kept for backward compatibility. Each name is an in-place
# initializer without its trailing underscore. Calling one emits a
# FutureWarning and then forwards all arguments to the ``*_`` function.
# (``trunc_normal_``, ``ones_`` and ``zeros_`` have no such alias.)
uniform = _make_deprecate(uniform_)
normal = _make_deprecate(normal_)
constant = _make_deprecate(constant_)
eye = _make_deprecate(eye_)
dirac = _make_deprecate(dirac_)
xavier_uniform = _make_deprecate(xavier_uniform_)
xavier_normal = _make_deprecate(xavier_normal_)
kaiming_uniform = _make_deprecate(kaiming_uniform_)
kaiming_normal = _make_deprecate(kaiming_normal_)
orthogonal = _make_deprecate(orthogonal_)
sparse = _make_deprecate(sparse_)
|
|