|
|
|
|
|
import math |
|
import numbers |
|
import warnings |
|
import weakref |
|
from typing import Optional, overload |
|
from typing_extensions import deprecated |
|
|
|
import torch |
|
from torch import _VF, Tensor |
|
from torch.nn import init |
|
from torch.nn.parameter import Parameter |
|
from torch.nn.utils.rnn import PackedSequence |
|
|
|
from .module import Module |
|
|
|
|
|
__all__ = [ |
|
"RNNBase", |
|
"RNN", |
|
"LSTM", |
|
"GRU", |
|
"RNNCellBase", |
|
"RNNCell", |
|
"LSTMCell", |
|
"GRUCell", |
|
] |
|
|
|
_rnn_impls = { |
|
"RNN_TANH": _VF.rnn_tanh, |
|
"RNN_RELU": _VF.rnn_relu, |
|
} |
|
|
|
|
|
def _apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: |
|
return tensor.index_select(dim, permutation) |
|
|
|
|
|
@deprecated( |
|
"`apply_permutation` is deprecated, please use `tensor.index_select(dim, permutation)` instead", |
|
category=FutureWarning, |
|
) |
|
def apply_permutation(tensor: Tensor, permutation: Tensor, dim: int = 1) -> Tensor: |
|
return _apply_permutation(tensor, permutation, dim) |
|
|
|
|
|
class RNNBase(Module): |
|
r"""Base class for RNN modules (RNN, LSTM, GRU). |
|
|
|
Implements aspects of RNNs shared by the RNN, LSTM, and GRU classes, such as module initialization |
|
and utility methods for parameter storage management. |
|
|
|
.. note:: |
|
The forward method is not implemented by the RNNBase class. |
|
|
|
.. note:: |
|
LSTM and GRU classes override some methods implemented by RNNBase. |
|
""" |
|
|
|
__constants__ = [ |
|
"mode", |
|
"input_size", |
|
"hidden_size", |
|
"num_layers", |
|
"bias", |
|
"batch_first", |
|
"dropout", |
|
"bidirectional", |
|
"proj_size", |
|
] |
|
__jit_unused_properties__ = ["all_weights"] |
|
|
|
mode: str |
|
input_size: int |
|
hidden_size: int |
|
num_layers: int |
|
bias: bool |
|
batch_first: bool |
|
dropout: float |
|
bidirectional: bool |
|
proj_size: int |
|
|
|
def __init__( |
|
self, |
|
mode: str, |
|
input_size: int, |
|
hidden_size: int, |
|
num_layers: int = 1, |
|
bias: bool = True, |
|
batch_first: bool = False, |
|
dropout: float = 0.0, |
|
bidirectional: bool = False, |
|
proj_size: int = 0, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__() |
|
self.mode = mode |
|
self.input_size = input_size |
|
self.hidden_size = hidden_size |
|
self.num_layers = num_layers |
|
self.bias = bias |
|
self.batch_first = batch_first |
|
self.dropout = float(dropout) |
|
self.bidirectional = bidirectional |
|
self.proj_size = proj_size |
|
self._flat_weight_refs: list[Optional[weakref.ReferenceType[Parameter]]] = [] |
|
num_directions = 2 if bidirectional else 1 |
|
|
|
if ( |
|
not isinstance(dropout, numbers.Number) |
|
or not 0 <= dropout <= 1 |
|
or isinstance(dropout, bool) |
|
): |
|
raise ValueError( |
|
"dropout should be a number in range [0, 1] " |
|
"representing the probability of an element being " |
|
"zeroed" |
|
) |
|
if dropout > 0 and num_layers == 1: |
|
warnings.warn( |
|
"dropout option adds dropout after all but last " |
|
"recurrent layer, so non-zero dropout expects " |
|
f"num_layers greater than 1, but got dropout={dropout} and " |
|
f"num_layers={num_layers}" |
|
) |
|
|
|
if not isinstance(hidden_size, int): |
|
raise TypeError( |
|
f"hidden_size should be of type int, got: {type(hidden_size).__name__}" |
|
) |
|
if hidden_size <= 0: |
|
raise ValueError("hidden_size must be greater than zero") |
|
if num_layers <= 0: |
|
raise ValueError("num_layers must be greater than zero") |
|
if proj_size < 0: |
|
raise ValueError( |
|
"proj_size should be a positive integer or zero to disable projections" |
|
) |
|
if proj_size >= hidden_size: |
|
raise ValueError("proj_size has to be smaller than hidden_size") |
|
|
|
if mode == "LSTM": |
|
gate_size = 4 * hidden_size |
|
elif mode == "GRU": |
|
gate_size = 3 * hidden_size |
|
elif mode == "RNN_TANH": |
|
gate_size = hidden_size |
|
elif mode == "RNN_RELU": |
|
gate_size = hidden_size |
|
else: |
|
raise ValueError("Unrecognized RNN mode: " + mode) |
|
|
|
self._flat_weights_names = [] |
|
self._all_weights = [] |
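        # Build one parameter set per (layer, direction) and register each tensor as a
        # module attribute, e.g. weight_ih_l0 or weight_hh_l1_reverse.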
|
for layer in range(num_layers): |
|
for direction in range(num_directions): |
|
real_hidden_size = proj_size if proj_size > 0 else hidden_size |
|
layer_input_size = ( |
|
input_size if layer == 0 else real_hidden_size * num_directions |
|
) |
|
|
|
w_ih = Parameter( |
|
torch.empty((gate_size, layer_input_size), **factory_kwargs) |
|
) |
|
w_hh = Parameter( |
|
torch.empty((gate_size, real_hidden_size), **factory_kwargs) |
|
) |
|
b_ih = Parameter(torch.empty(gate_size, **factory_kwargs)) |
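                # A second bias vector is included for cuDNN compatibility; only one
                # bias vector is needed in the standard definition.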
|
|
|
|
|
b_hh = Parameter(torch.empty(gate_size, **factory_kwargs)) |
|
layer_params: tuple[Tensor, ...] = () |
|
if self.proj_size == 0: |
|
if bias: |
|
layer_params = (w_ih, w_hh, b_ih, b_hh) |
|
else: |
|
layer_params = (w_ih, w_hh) |
|
else: |
|
w_hr = Parameter( |
|
torch.empty((proj_size, hidden_size), **factory_kwargs) |
|
) |
|
if bias: |
|
layer_params = (w_ih, w_hh, b_ih, b_hh, w_hr) |
|
else: |
|
layer_params = (w_ih, w_hh, w_hr) |
|
|
|
suffix = "_reverse" if direction == 1 else "" |
|
param_names = ["weight_ih_l{}{}", "weight_hh_l{}{}"] |
|
if bias: |
|
param_names += ["bias_ih_l{}{}", "bias_hh_l{}{}"] |
|
if self.proj_size > 0: |
|
param_names += ["weight_hr_l{}{}"] |
|
param_names = [x.format(layer, suffix) for x in param_names] |
|
|
|
for name, param in zip(param_names, layer_params): |
|
setattr(self, name, param) |
|
self._flat_weights_names.extend(param_names) |
|
self._all_weights.append(param_names) |
|
|
|
self._init_flat_weights() |
|
|
|
self.reset_parameters() |
|
|
|
def _init_flat_weights(self): |
|
self._flat_weights = [ |
|
getattr(self, wn) if hasattr(self, wn) else None |
|
for wn in self._flat_weights_names |
|
] |
|
self._flat_weight_refs = [ |
|
weakref.ref(w) if w is not None else None for w in self._flat_weights |
|
] |
|
self.flatten_parameters() |
|
|
|
def __setattr__(self, attr, value): |
|
if hasattr(self, "_flat_weights_names") and attr in self._flat_weights_names: |
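            # Keep self._flat_weights up to date when a flat weight is reassigned
            # directly (e.g. self.weight_ih_l0 = new_param).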
|
|
|
idx = self._flat_weights_names.index(attr) |
|
self._flat_weights[idx] = value |
|
super().__setattr__(attr, value) |
|
|
|
def flatten_parameters(self) -> None: |
|
"""Reset parameter data pointer so that they can use faster code paths. |
|
|
|
Right now, this works only if the module is on the GPU and cuDNN is enabled. |
|
Otherwise, it's a no-op. |
|
""" |
|
|
|
if len(self._flat_weights) != len(self._flat_weights_names): |
|
return |
|
|
|
for w in self._flat_weights: |
|
if not isinstance(w, Tensor): |
|
return |
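        # Short-circuit if any tensor in self._flat_weights is not acceptable to
        # cuDNN or if the flat weights have mismatched dtypes.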
|
|
|
|
|
|
|
first_fw = self._flat_weights[0] |
|
dtype = first_fw.dtype |
|
for fw in self._flat_weights: |
|
if ( |
|
not isinstance(fw, Tensor) |
|
or not (fw.dtype == dtype) |
|
or not fw.is_cuda |
|
or not torch.backends.cudnn.is_acceptable(fw) |
|
): |
|
return |
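        # If any parameters alias, fall back to the slower, copying code path;
        # overlapping buffers would break the uniqueness assumptions of
        # Module.named_parameters().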
|
|
|
|
|
|
|
|
|
|
|
        unique_data_ptrs = {p.data_ptr() for p in self._flat_weights}
|
if len(unique_data_ptrs) != len(self._flat_weights): |
|
return |
|
|
|
with torch.cuda.device_of(first_fw): |
|
import torch.backends.cudnn.rnn as rnn |
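            # no_grad() is necessary because _cudnn_rnn_flatten_weight modifies
            # self._flat_weights in place.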
|
|
|
|
|
|
|
with torch.no_grad(): |
|
if torch._use_cudnn_rnn_flatten_weight(): |
|
num_weights = 4 if self.bias else 2 |
|
if self.proj_size > 0: |
|
num_weights += 1 |
|
torch._cudnn_rnn_flatten_weight( |
|
self._flat_weights, |
|
num_weights, |
|
self.input_size, |
|
rnn.get_cudnn_mode(self.mode), |
|
self.hidden_size, |
|
self.proj_size, |
|
self.num_layers, |
|
self.batch_first, |
|
bool(self.bidirectional), |
|
) |
|
|
|
def _apply(self, fn, recurse=True): |
|
self._flat_weight_refs = [] |
|
ret = super()._apply(fn, recurse) |
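        # _apply() may have re-created the parameter tensors (e.g. via .to() or
        # .float()), so rebuild the flat-weight views and their weak references.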
|
|
|
|
|
|
|
|
|
self._init_flat_weights() |
|
|
|
return ret |
|
|
|
def reset_parameters(self) -> None: |
|
stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 |
|
for weight in self.parameters(): |
|
init.uniform_(weight, -stdv, stdv) |
|
|
|
def check_input(self, input: Tensor, batch_sizes: Optional[Tensor]) -> None: |
|
if not torch.jit.is_scripting(): |
|
if ( |
|
input.dtype != self._flat_weights[0].dtype |
|
and not torch._C._is_any_autocast_enabled() |
|
): |
|
raise ValueError( |
|
f"input must have the type {self._flat_weights[0].dtype}, got type {input.dtype}" |
|
) |
|
expected_input_dim = 2 if batch_sizes is not None else 3 |
|
if input.dim() != expected_input_dim: |
|
raise RuntimeError( |
|
f"input must have {expected_input_dim} dimensions, got {input.dim()}" |
|
) |
|
if self.input_size != input.size(-1): |
|
raise RuntimeError( |
|
f"input.size(-1) must be equal to input_size. Expected {self.input_size}, got {input.size(-1)}" |
|
) |
|
|
|
def get_expected_hidden_size( |
|
self, input: Tensor, batch_sizes: Optional[Tensor] |
|
) -> tuple[int, int, int]: |
|
if batch_sizes is not None: |
|
mini_batch = int(batch_sizes[0]) |
|
else: |
|
mini_batch = input.size(0) if self.batch_first else input.size(1) |
|
num_directions = 2 if self.bidirectional else 1 |
|
if self.proj_size > 0: |
|
expected_hidden_size = ( |
|
self.num_layers * num_directions, |
|
mini_batch, |
|
self.proj_size, |
|
) |
|
else: |
|
expected_hidden_size = ( |
|
self.num_layers * num_directions, |
|
mini_batch, |
|
self.hidden_size, |
|
) |
|
return expected_hidden_size |
|
|
|
def check_hidden_size( |
|
self, |
|
hx: Tensor, |
|
expected_hidden_size: tuple[int, int, int], |
|
msg: str = "Expected hidden size {}, got {}", |
|
) -> None: |
|
if hx.size() != expected_hidden_size: |
|
raise RuntimeError(msg.format(expected_hidden_size, list(hx.size()))) |
|
|
|
def _weights_have_changed(self): |
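        # Returns True if the weight tensors have changed since the last forward pass,
        # e.g. when the module is called through torch.func.functional_call().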
|
|
|
|
|
weights_changed = False |
|
for ref, name in zip(self._flat_weight_refs, self._flat_weights_names): |
|
weight = getattr(self, name) if hasattr(self, name) else None |
|
if weight is not None and ref is not None and ref() is not weight: |
|
weights_changed = True |
|
break |
|
return weights_changed |
|
|
|
def check_forward_args( |
|
self, input: Tensor, hidden: Tensor, batch_sizes: Optional[Tensor] |
|
): |
|
self.check_input(input, batch_sizes) |
|
expected_hidden_size = self.get_expected_hidden_size(input, batch_sizes) |
|
|
|
self.check_hidden_size(hidden, expected_hidden_size) |
|
|
|
def permute_hidden(self, hx: Tensor, permutation: Optional[Tensor]): |
|
if permutation is None: |
|
return hx |
|
return _apply_permutation(hx, permutation) |
|
|
|
def extra_repr(self) -> str: |
|
s = "{input_size}, {hidden_size}" |
|
if self.proj_size != 0: |
|
s += ", proj_size={proj_size}" |
|
if self.num_layers != 1: |
|
s += ", num_layers={num_layers}" |
|
if self.bias is not True: |
|
s += ", bias={bias}" |
|
if self.batch_first is not False: |
|
s += ", batch_first={batch_first}" |
|
if self.dropout != 0: |
|
s += ", dropout={dropout}" |
|
if self.bidirectional is not False: |
|
s += ", bidirectional={bidirectional}" |
|
return s.format(**self.__dict__) |
|
|
|
def _update_flat_weights(self): |
|
if not torch.jit.is_scripting(): |
|
if self._weights_have_changed(): |
|
self._init_flat_weights() |
|
|
|
def __getstate__(self): |
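        # Flush any reassigned weights into _flat_weights, then drop the weak
        # references, which cannot be pickled (they are rebuilt in __setstate__).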
|
|
|
self._update_flat_weights() |
|
|
|
state = self.__dict__.copy() |
|
del state["_flat_weight_refs"] |
|
return state |
|
|
|
def __setstate__(self, d): |
|
super().__setstate__(d) |
|
if "all_weights" in d: |
|
self._all_weights = d["all_weights"] |
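        # Checkpoints saved before LSTM gained the proj_size option do not carry
        # this attribute; default it to 0 for backward compatibility.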
|
|
|
|
|
|
|
if "proj_size" not in d: |
|
self.proj_size = 0 |
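        # In legacy serialized modules, _all_weights stored the parameter tensors
        # themselves rather than their names; rebuild the name lists in that case.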
|
|
|
if not isinstance(self._all_weights[0][0], str): |
|
num_layers = self.num_layers |
|
num_directions = 2 if self.bidirectional else 1 |
|
self._flat_weights_names = [] |
|
self._all_weights = [] |
|
for layer in range(num_layers): |
|
for direction in range(num_directions): |
|
suffix = "_reverse" if direction == 1 else "" |
|
weights = [ |
|
"weight_ih_l{}{}", |
|
"weight_hh_l{}{}", |
|
"bias_ih_l{}{}", |
|
"bias_hh_l{}{}", |
|
"weight_hr_l{}{}", |
|
] |
|
weights = [x.format(layer, suffix) for x in weights] |
|
if self.bias: |
|
if self.proj_size > 0: |
|
self._all_weights += [weights] |
|
self._flat_weights_names.extend(weights) |
|
else: |
|
self._all_weights += [weights[:4]] |
|
self._flat_weights_names.extend(weights[:4]) |
|
else: |
|
if self.proj_size > 0: |
|
self._all_weights += [weights[:2]] + [weights[-1:]] |
|
self._flat_weights_names.extend( |
|
weights[:2] + [weights[-1:]] |
|
) |
|
else: |
|
self._all_weights += [weights[:2]] |
|
self._flat_weights_names.extend(weights[:2]) |
|
self._flat_weights = [ |
|
getattr(self, wn) if hasattr(self, wn) else None |
|
for wn in self._flat_weights_names |
|
] |
|
|
|
self._flat_weight_refs = [ |
|
weakref.ref(w) if w is not None else None for w in self._flat_weights |
|
] |
|
|
|
@property |
|
def all_weights(self) -> list[list[Parameter]]: |
|
return [ |
|
[getattr(self, weight) for weight in weights] |
|
for weights in self._all_weights |
|
] |
|
|
|
def _replicate_for_data_parallel(self): |
|
replica = super()._replicate_for_data_parallel() |
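        # Shallow-copy the cached lists so the replica does not share (and later
        # mutate) this module's flat-weight bookkeeping.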
|
|
|
|
|
replica._flat_weights = replica._flat_weights[:] |
|
replica._flat_weights_names = replica._flat_weights_names[:] |
|
return replica |
|
|
|
|
|
class RNN(RNNBase): |
|
r"""__init__(input_size,hidden_size,num_layers=1,nonlinearity='tanh',bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) |
|
|
|
Apply a multi-layer Elman RNN with :math:`\tanh` or :math:`\text{ReLU}` |
|
non-linearity to an input sequence. For each element in the input sequence, |
|
each layer computes the following function: |
|
|
|
.. math:: |
|
h_t = \tanh(x_t W_{ih}^T + b_{ih} + h_{t-1}W_{hh}^T + b_{hh}) |
|
|
|
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is |
|
    the input at time `t`, and :math:`h_{(t-1)}` is the hidden state of the
    layer at time `t-1` or the initial hidden state at time `0`.
|
If :attr:`nonlinearity` is ``'relu'``, then :math:`\text{ReLU}` is used instead of :math:`\tanh`. |
|
|
|
.. code-block:: python |
|
|
|
# Efficient implementation equivalent to the following with bidirectional=False |
|
def forward(x, hx=None): |
|
if batch_first: |
|
x = x.transpose(0, 1) |
|
seq_len, batch_size, _ = x.size() |
|
if hx is None: |
|
hx = torch.zeros(num_layers, batch_size, hidden_size) |
|
h_t_minus_1 = hx |
|
h_t = hx |
|
output = [] |
|
for t in range(seq_len): |
|
            for layer in range(num_layers):
                # layer 0 reads the external input x[t]; deeper layers read the
                # output of the layer below at the same time step
                layer_input = x[t] if layer == 0 else h_t[layer - 1]
                h_t[layer] = torch.tanh(
                    layer_input @ weight_ih[layer].T
                    + bias_ih[layer]
                    + h_t_minus_1[layer] @ weight_hh[layer].T
                    + bias_hh[layer]
                )
|
output.append(h_t[-1]) |
|
h_t_minus_1 = h_t |
|
output = torch.stack(output) |
|
if batch_first: |
|
output = output.transpose(0, 1) |
|
return output, h_t |
|
|
|
Args: |
|
input_size: The number of expected features in the input `x` |
|
hidden_size: The number of features in the hidden state `h` |
|
num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` |
|
would mean stacking two RNNs together to form a `stacked RNN`, |
|
with the second RNN taking in outputs of the first RNN and |
|
computing the final results. Default: 1 |
|
nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` |
|
bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. |
|
Default: ``True`` |
|
batch_first: If ``True``, then the input and output tensors are provided |
|
as `(batch, seq, feature)` instead of `(seq, batch, feature)`. |
|
Note that this does not apply to hidden or cell states. See the |
|
Inputs/Outputs sections below for details. Default: ``False`` |
|
dropout: If non-zero, introduces a `Dropout` layer on the outputs of each |
|
RNN layer except the last layer, with dropout probability equal to |
|
:attr:`dropout`. Default: 0 |
|
bidirectional: If ``True``, becomes a bidirectional RNN. Default: ``False`` |
|
|
|
Inputs: input, hx |
|
* **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, |
|
:math:`(L, N, H_{in})` when ``batch_first=False`` or |
|
:math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of |
|
the input sequence. The input can also be a packed variable length sequence. |
|
See :func:`torch.nn.utils.rnn.pack_padded_sequence` or |
|
:func:`torch.nn.utils.rnn.pack_sequence` for details. |
|
* **hx**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or |
|
:math:`(D * \text{num\_layers}, N, H_{out})` containing the initial hidden |
|
state for the input sequence batch. Defaults to zeros if not provided. |
|
|
|
where: |
|
|
|
.. math:: |
|
\begin{aligned} |
|
N ={} & \text{batch size} \\ |
|
L ={} & \text{sequence length} \\ |
|
D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ |
|
H_{in} ={} & \text{input\_size} \\ |
|
H_{out} ={} & \text{hidden\_size} |
|
\end{aligned} |
|
|
|
Outputs: output, h_n |
|
* **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, |
|
:math:`(L, N, D * H_{out})` when ``batch_first=False`` or |
|
:math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features |
|
`(h_t)` from the last layer of the RNN, for each `t`. If a |
|
:class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output |
|
will also be a packed sequence. |
|
* **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or |
|
:math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state |
|
for each element in the batch. |
|
|
|
Attributes: |
|
weight_ih_l[k]: the learnable input-hidden weights of the k-th layer, |
|
of shape `(hidden_size, input_size)` for `k = 0`. Otherwise, the shape is |
|
`(hidden_size, num_directions * hidden_size)` |
|
weight_hh_l[k]: the learnable hidden-hidden weights of the k-th layer, |
|
of shape `(hidden_size, hidden_size)` |
|
bias_ih_l[k]: the learnable input-hidden bias of the k-th layer, |
|
of shape `(hidden_size)` |
|
bias_hh_l[k]: the learnable hidden-hidden bias of the k-th layer, |
|
of shape `(hidden_size)` |
|
|
|
.. note:: |
|
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` |
|
where :math:`k = \frac{1}{\text{hidden\_size}}` |
|
|
|
.. note:: |
|
For bidirectional RNNs, forward and backward are directions 0 and 1 respectively. |
|
Example of splitting the output layers when ``batch_first=False``: |
|
``output.view(seq_len, batch, num_directions, hidden_size)``. |
|
|
|
.. note:: |
|
``batch_first`` argument is ignored for unbatched inputs. |
|
|
|
.. include:: ../cudnn_rnn_determinism.rst |
|
|
|
.. include:: ../cudnn_persistent_rnn.rst |
|
|
|
Examples:: |
|
|
|
>>> rnn = nn.RNN(10, 20, 2) |
|
>>> input = torch.randn(5, 3, 10) |
|
>>> h0 = torch.randn(2, 3, 20) |
|
>>> output, hn = rnn(input, h0) |
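
        A minimal sketch with a packed variable-length batch; ``lengths`` below is an
        assumed example of per-sequence lengths, sorted in decreasing order as required
        by the default ``enforce_sorted=True``::

            >>> rnn = nn.RNN(10, 20, 2)
            >>> padded = torch.randn(5, 3, 10)
            >>> lengths = torch.tensor([5, 3, 2])
            >>> packed = nn.utils.rnn.pack_padded_sequence(padded, lengths)
            >>> packed_output, hn = rnn(packed)
            >>> output, out_lengths = nn.utils.rnn.pad_packed_sequence(packed_output)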
|
""" |
|
|
|
@overload |
|
def __init__( |
|
self, |
|
input_size: int, |
|
hidden_size: int, |
|
num_layers: int = 1, |
|
nonlinearity: str = "tanh", |
|
bias: bool = True, |
|
batch_first: bool = False, |
|
dropout: float = 0.0, |
|
bidirectional: bool = False, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
... |
|
|
|
@overload |
|
def __init__(self, *args, **kwargs): |
|
... |
|
|
|
def __init__(self, *args, **kwargs): |
|
if "proj_size" in kwargs: |
|
raise ValueError( |
|
"proj_size argument is only supported for LSTM, not RNN or GRU" |
|
) |
|
if len(args) > 3: |
|
self.nonlinearity = args[3] |
|
args = args[:3] + args[4:] |
|
else: |
|
self.nonlinearity = kwargs.pop("nonlinearity", "tanh") |
|
if self.nonlinearity == "tanh": |
|
mode = "RNN_TANH" |
|
elif self.nonlinearity == "relu": |
|
mode = "RNN_RELU" |
|
else: |
|
raise ValueError( |
|
f"Unknown nonlinearity '{self.nonlinearity}'. Select from 'tanh' or 'relu'." |
|
) |
|
super().__init__(mode, *args, **kwargs) |
|
|
|
@overload |
|
@torch._jit_internal._overload_method |
|
def forward( |
|
self, input: Tensor, hx: Optional[Tensor] = None |
|
) -> tuple[Tensor, Tensor]: |
|
pass |
|
|
|
@overload |
|
@torch._jit_internal._overload_method |
|
def forward( |
|
self, input: PackedSequence, hx: Optional[Tensor] = None |
|
) -> tuple[PackedSequence, Tensor]: |
|
pass |
|
|
|
def forward(self, input, hx=None): |
|
self._update_flat_weights() |
|
|
|
num_directions = 2 if self.bidirectional else 1 |
|
orig_input = input |
|
|
|
if isinstance(orig_input, PackedSequence): |
|
input, batch_sizes, sorted_indices, unsorted_indices = input |
|
max_batch_size = batch_sizes[0] |
|
|
|
if hx is None: |
|
hx = torch.zeros( |
|
self.num_layers * num_directions, |
|
max_batch_size, |
|
self.hidden_size, |
|
dtype=input.dtype, |
|
device=input.device, |
|
) |
|
else: |
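                # Each batch of the hidden state should match the input sequence
                # that the user believes they are passing in.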
|
|
|
|
|
hx = self.permute_hidden(hx, sorted_indices) |
|
else: |
|
batch_sizes = None |
|
if input.dim() not in (2, 3): |
|
raise ValueError( |
|
f"RNN: Expected input to be 2D or 3D, got {input.dim()}D tensor instead" |
|
) |
|
is_batched = input.dim() == 3 |
|
batch_dim = 0 if self.batch_first else 1 |
|
if not is_batched: |
|
input = input.unsqueeze(batch_dim) |
|
if hx is not None: |
|
if hx.dim() != 2: |
|
raise RuntimeError( |
|
f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" |
|
) |
|
hx = hx.unsqueeze(1) |
|
else: |
|
if hx is not None and hx.dim() != 3: |
|
raise RuntimeError( |
|
f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" |
|
) |
|
max_batch_size = input.size(0) if self.batch_first else input.size(1) |
|
sorted_indices = None |
|
unsorted_indices = None |
|
if hx is None: |
|
hx = torch.zeros( |
|
self.num_layers * num_directions, |
|
max_batch_size, |
|
self.hidden_size, |
|
dtype=input.dtype, |
|
device=input.device, |
|
) |
|
else: |
|
|
|
|
|
hx = self.permute_hidden(hx, sorted_indices) |
|
|
|
assert hx is not None |
|
self.check_forward_args(input, hx, batch_sizes) |
|
assert self.mode == "RNN_TANH" or self.mode == "RNN_RELU" |
|
if batch_sizes is None: |
|
if self.mode == "RNN_TANH": |
|
result = _VF.rnn_tanh( |
|
input, |
|
hx, |
|
self._flat_weights, |
|
self.bias, |
|
self.num_layers, |
|
self.dropout, |
|
self.training, |
|
self.bidirectional, |
|
self.batch_first, |
|
) |
|
else: |
|
result = _VF.rnn_relu( |
|
input, |
|
hx, |
|
self._flat_weights, |
|
self.bias, |
|
self.num_layers, |
|
self.dropout, |
|
self.training, |
|
self.bidirectional, |
|
self.batch_first, |
|
) |
|
else: |
|
if self.mode == "RNN_TANH": |
|
result = _VF.rnn_tanh( |
|
input, |
|
batch_sizes, |
|
hx, |
|
self._flat_weights, |
|
self.bias, |
|
self.num_layers, |
|
self.dropout, |
|
self.training, |
|
self.bidirectional, |
|
) |
|
else: |
|
result = _VF.rnn_relu( |
|
input, |
|
batch_sizes, |
|
hx, |
|
self._flat_weights, |
|
self.bias, |
|
self.num_layers, |
|
self.dropout, |
|
self.training, |
|
self.bidirectional, |
|
) |
|
|
|
output = result[0] |
|
hidden = result[1] |
|
|
|
if isinstance(orig_input, PackedSequence): |
|
output_packed = PackedSequence( |
|
output, batch_sizes, sorted_indices, unsorted_indices |
|
) |
|
return output_packed, self.permute_hidden(hidden, unsorted_indices) |
|
|
|
if not is_batched: |
|
output = output.squeeze(batch_dim) |
|
hidden = hidden.squeeze(1) |
|
|
|
return output, self.permute_hidden(hidden, unsorted_indices) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LSTM(RNNBase): |
|
r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,proj_size=0,device=None,dtype=None) |
|
|
|
Apply a multi-layer long short-term memory (LSTM) RNN to an input sequence. |
|
For each element in the input sequence, each layer computes the following |
|
function: |
|
|
|
.. math:: |
|
\begin{array}{ll} \\ |
|
i_t = \sigma(W_{ii} x_t + b_{ii} + W_{hi} h_{t-1} + b_{hi}) \\ |
|
f_t = \sigma(W_{if} x_t + b_{if} + W_{hf} h_{t-1} + b_{hf}) \\ |
|
g_t = \tanh(W_{ig} x_t + b_{ig} + W_{hg} h_{t-1} + b_{hg}) \\ |
|
o_t = \sigma(W_{io} x_t + b_{io} + W_{ho} h_{t-1} + b_{ho}) \\ |
|
c_t = f_t \odot c_{t-1} + i_t \odot g_t \\ |
|
h_t = o_t \odot \tanh(c_t) \\ |
|
\end{array} |
|
|
|
where :math:`h_t` is the hidden state at time `t`, :math:`c_t` is the cell |
|
state at time `t`, :math:`x_t` is the input at time `t`, :math:`h_{t-1}` |
|
is the hidden state of the layer at time `t-1` or the initial hidden |
|
state at time `0`, and :math:`i_t`, :math:`f_t`, :math:`g_t`, |
|
:math:`o_t` are the input, forget, cell, and output gates, respectively. |
|
:math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. |
|
|
|
In a multilayer LSTM, the input :math:`x^{(l)}_t` of the :math:`l` -th layer |
|
(:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by |
|
dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random |
|
variable which is :math:`0` with probability :attr:`dropout`. |
|
|
|
If ``proj_size > 0`` is specified, LSTM with projections will be used. This changes |
|
the LSTM cell in the following way. First, the dimension of :math:`h_t` will be changed from |
|
``hidden_size`` to ``proj_size`` (dimensions of :math:`W_{hi}` will be changed accordingly). |
|
Second, the output hidden state of each layer will be multiplied by a learnable projection |
|
    matrix: :math:`h_t = W_{hr}h_t`. Note that, as a consequence, the output of the
    LSTM network will have a different shape as well. See the Inputs/Outputs sections below for the exact
|
dimensions of all variables. You can find more details in https://arxiv.org/abs/1402.1128. |
|
|
|
Args: |
|
input_size: The number of expected features in the input `x` |
|
hidden_size: The number of features in the hidden state `h` |
|
num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` |
|
would mean stacking two LSTMs together to form a `stacked LSTM`, |
|
with the second LSTM taking in outputs of the first LSTM and |
|
computing the final results. Default: 1 |
|
bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. |
|
Default: ``True`` |
|
batch_first: If ``True``, then the input and output tensors are provided |
|
as `(batch, seq, feature)` instead of `(seq, batch, feature)`. |
|
Note that this does not apply to hidden or cell states. See the |
|
Inputs/Outputs sections below for details. Default: ``False`` |
|
dropout: If non-zero, introduces a `Dropout` layer on the outputs of each |
|
LSTM layer except the last layer, with dropout probability equal to |
|
:attr:`dropout`. Default: 0 |
|
bidirectional: If ``True``, becomes a bidirectional LSTM. Default: ``False`` |
|
proj_size: If ``> 0``, will use LSTM with projections of corresponding size. Default: 0 |
|
|
|
Inputs: input, (h_0, c_0) |
|
* **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, |
|
:math:`(L, N, H_{in})` when ``batch_first=False`` or |
|
:math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of |
|
the input sequence. The input can also be a packed variable length sequence. |
|
See :func:`torch.nn.utils.rnn.pack_padded_sequence` or |
|
:func:`torch.nn.utils.rnn.pack_sequence` for details. |
|
* **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or |
|
:math:`(D * \text{num\_layers}, N, H_{out})` containing the |
|
initial hidden state for each element in the input sequence. |
|
Defaults to zeros if (h_0, c_0) is not provided. |
|
* **c_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or |
|
:math:`(D * \text{num\_layers}, N, H_{cell})` containing the |
|
initial cell state for each element in the input sequence. |
|
Defaults to zeros if (h_0, c_0) is not provided. |
|
|
|
where: |
|
|
|
.. math:: |
|
\begin{aligned} |
|
N ={} & \text{batch size} \\ |
|
L ={} & \text{sequence length} \\ |
|
D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ |
|
H_{in} ={} & \text{input\_size} \\ |
|
H_{cell} ={} & \text{hidden\_size} \\ |
|
H_{out} ={} & \text{proj\_size if } \text{proj\_size}>0 \text{ otherwise hidden\_size} \\ |
|
\end{aligned} |
|
|
|
Outputs: output, (h_n, c_n) |
|
* **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, |
|
:math:`(L, N, D * H_{out})` when ``batch_first=False`` or |
|
:math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features |
|
`(h_t)` from the last layer of the LSTM, for each `t`. If a |
|
:class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output |
|
will also be a packed sequence. When ``bidirectional=True``, `output` will contain |
|
a concatenation of the forward and reverse hidden states at each time step in the sequence. |
|
* **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` for unbatched input or |
|
:math:`(D * \text{num\_layers}, N, H_{out})` containing the |
|
final hidden state for each element in the sequence. When ``bidirectional=True``, |
|
`h_n` will contain a concatenation of the final forward and reverse hidden states, respectively. |
|
* **c_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{cell})` for unbatched input or |
|
:math:`(D * \text{num\_layers}, N, H_{cell})` containing the |
|
final cell state for each element in the sequence. When ``bidirectional=True``, |
|
`c_n` will contain a concatenation of the final forward and reverse cell states, respectively. |
|
|
|
Attributes: |
|
weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer |
|
`(W_ii|W_if|W_ig|W_io)`, of shape `(4*hidden_size, input_size)` for `k = 0`. |
|
Otherwise, the shape is `(4*hidden_size, num_directions * hidden_size)`. If |
|
``proj_size > 0`` was specified, the shape will be |
|
`(4*hidden_size, num_directions * proj_size)` for `k > 0` |
|
weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer |
|
`(W_hi|W_hf|W_hg|W_ho)`, of shape `(4*hidden_size, hidden_size)`. If ``proj_size > 0`` |
|
was specified, the shape will be `(4*hidden_size, proj_size)`. |
|
bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer |
|
`(b_ii|b_if|b_ig|b_io)`, of shape `(4*hidden_size)` |
|
bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer |
|
`(b_hi|b_hf|b_hg|b_ho)`, of shape `(4*hidden_size)` |
|
weight_hr_l[k] : the learnable projection weights of the :math:`\text{k}^{th}` layer |
|
of shape `(proj_size, hidden_size)`. Only present when ``proj_size > 0`` was |
|
specified. |
|
weight_ih_l[k]_reverse: Analogous to `weight_ih_l[k]` for the reverse direction. |
|
Only present when ``bidirectional=True``. |
|
weight_hh_l[k]_reverse: Analogous to `weight_hh_l[k]` for the reverse direction. |
|
Only present when ``bidirectional=True``. |
|
bias_ih_l[k]_reverse: Analogous to `bias_ih_l[k]` for the reverse direction. |
|
Only present when ``bidirectional=True``. |
|
bias_hh_l[k]_reverse: Analogous to `bias_hh_l[k]` for the reverse direction. |
|
Only present when ``bidirectional=True``. |
|
weight_hr_l[k]_reverse: Analogous to `weight_hr_l[k]` for the reverse direction. |
|
Only present when ``bidirectional=True`` and ``proj_size > 0`` was specified. |
|
|
|
.. note:: |
|
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` |
|
where :math:`k = \frac{1}{\text{hidden\_size}}` |
|
|
|
.. note:: |
|
For bidirectional LSTMs, forward and backward are directions 0 and 1 respectively. |
|
Example of splitting the output layers when ``batch_first=False``: |
|
``output.view(seq_len, batch, num_directions, hidden_size)``. |
|
|
|
.. note:: |
|
For bidirectional LSTMs, `h_n` is not equivalent to the last element of `output`; the |
|
former contains the final forward and reverse hidden states, while the latter contains the |
|
final forward hidden state and the initial reverse hidden state. |
|
|
|
.. note:: |
|
``batch_first`` argument is ignored for unbatched inputs. |
|
|
|
.. note:: |
|
``proj_size`` should be smaller than ``hidden_size``. |
|
|
|
.. include:: ../cudnn_rnn_determinism.rst |
|
|
|
.. include:: ../cudnn_persistent_rnn.rst |
|
|
|
Examples:: |
|
|
|
>>> rnn = nn.LSTM(10, 20, 2) |
|
>>> input = torch.randn(5, 3, 10) |
|
>>> h0 = torch.randn(2, 3, 20) |
|
>>> c0 = torch.randn(2, 3, 20) |
|
>>> output, (hn, cn) = rnn(input, (h0, c0)) |
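
        A minimal sketch of an LSTM with projections (``proj_size > 0``); note that
        ``h0`` uses ``proj_size`` as its last dimension while ``c0`` keeps
        ``hidden_size``::

            >>> rnn = nn.LSTM(10, 20, 2, proj_size=5)
            >>> input = torch.randn(5, 3, 10)
            >>> h0 = torch.randn(2, 3, 5)
            >>> c0 = torch.randn(2, 3, 20)
            >>> output, (hn, cn) = rnn(input, (h0, c0))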
|
""" |
|
|
|
@overload |
|
def __init__( |
|
self, |
|
input_size: int, |
|
hidden_size: int, |
|
num_layers: int = 1, |
|
bias: bool = True, |
|
batch_first: bool = False, |
|
dropout: float = 0.0, |
|
bidirectional: bool = False, |
|
proj_size: int = 0, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
... |
|
|
|
@overload |
|
def __init__(self, *args, **kwargs): |
|
... |
|
|
|
def __init__(self, *args, **kwargs): |
|
super().__init__("LSTM", *args, **kwargs) |
|
|
|
def get_expected_cell_size( |
|
self, input: Tensor, batch_sizes: Optional[Tensor] |
|
) -> tuple[int, int, int]: |
|
if batch_sizes is not None: |
|
mini_batch = int(batch_sizes[0]) |
|
else: |
|
mini_batch = input.size(0) if self.batch_first else input.size(1) |
|
num_directions = 2 if self.bidirectional else 1 |
|
expected_hidden_size = ( |
|
self.num_layers * num_directions, |
|
mini_batch, |
|
self.hidden_size, |
|
) |
|
return expected_hidden_size |
|
|
|
|
|
|
|
def check_forward_args( |
|
self, |
|
input: Tensor, |
|
hidden: tuple[Tensor, Tensor], |
|
batch_sizes: Optional[Tensor], |
|
): |
|
self.check_input(input, batch_sizes) |
|
self.check_hidden_size( |
|
hidden[0], |
|
self.get_expected_hidden_size(input, batch_sizes), |
|
"Expected hidden[0] size {}, got {}", |
|
) |
|
self.check_hidden_size( |
|
hidden[1], |
|
self.get_expected_cell_size(input, batch_sizes), |
|
"Expected hidden[1] size {}, got {}", |
|
) |
|
|
|
|
|
def permute_hidden( |
|
self, |
|
hx: tuple[Tensor, Tensor], |
|
permutation: Optional[Tensor], |
|
) -> tuple[Tensor, Tensor]: |
|
if permutation is None: |
|
return hx |
|
return _apply_permutation(hx[0], permutation), _apply_permutation( |
|
hx[1], permutation |
|
) |
|
|
|
|
|
@overload |
|
@torch._jit_internal._overload_method |
|
def forward( |
|
self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None |
|
) -> tuple[Tensor, tuple[Tensor, Tensor]]: |
|
pass |
|
|
|
|
|
@overload |
|
@torch._jit_internal._overload_method |
|
def forward( |
|
self, input: PackedSequence, hx: Optional[tuple[Tensor, Tensor]] = None |
|
) -> tuple[PackedSequence, tuple[Tensor, Tensor]]: |
|
pass |
|
|
|
def forward(self, input, hx=None): |
|
self._update_flat_weights() |
|
|
|
orig_input = input |
|
|
|
batch_sizes = None |
|
num_directions = 2 if self.bidirectional else 1 |
|
real_hidden_size = self.proj_size if self.proj_size > 0 else self.hidden_size |
|
if isinstance(orig_input, PackedSequence): |
|
input, batch_sizes, sorted_indices, unsorted_indices = input |
|
max_batch_size = batch_sizes[0] |
|
if hx is None: |
|
h_zeros = torch.zeros( |
|
self.num_layers * num_directions, |
|
max_batch_size, |
|
real_hidden_size, |
|
dtype=input.dtype, |
|
device=input.device, |
|
) |
|
c_zeros = torch.zeros( |
|
self.num_layers * num_directions, |
|
max_batch_size, |
|
self.hidden_size, |
|
dtype=input.dtype, |
|
device=input.device, |
|
) |
|
hx = (h_zeros, c_zeros) |
|
else: |
|
|
|
|
|
hx = self.permute_hidden(hx, sorted_indices) |
|
else: |
|
if input.dim() not in (2, 3): |
|
raise ValueError( |
|
f"LSTM: Expected input to be 2D or 3D, got {input.dim()}D instead" |
|
) |
|
is_batched = input.dim() == 3 |
|
batch_dim = 0 if self.batch_first else 1 |
|
if not is_batched: |
|
input = input.unsqueeze(batch_dim) |
|
max_batch_size = input.size(0) if self.batch_first else input.size(1) |
|
sorted_indices = None |
|
unsorted_indices = None |
|
if hx is None: |
|
h_zeros = torch.zeros( |
|
self.num_layers * num_directions, |
|
max_batch_size, |
|
real_hidden_size, |
|
dtype=input.dtype, |
|
device=input.device, |
|
) |
|
c_zeros = torch.zeros( |
|
self.num_layers * num_directions, |
|
max_batch_size, |
|
self.hidden_size, |
|
dtype=input.dtype, |
|
device=input.device, |
|
) |
|
hx = (h_zeros, c_zeros) |
|
self.check_forward_args(input, hx, batch_sizes) |
|
else: |
|
if is_batched: |
|
if hx[0].dim() != 3 or hx[1].dim() != 3: |
|
msg = ( |
|
"For batched 3-D input, hx and cx should " |
|
f"also be 3-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" |
|
) |
|
raise RuntimeError(msg) |
|
else: |
|
if hx[0].dim() != 2 or hx[1].dim() != 2: |
|
msg = ( |
|
"For unbatched 2-D input, hx and cx should " |
|
f"also be 2-D but got ({hx[0].dim()}-D, {hx[1].dim()}-D) tensors" |
|
) |
|
raise RuntimeError(msg) |
|
hx = (hx[0].unsqueeze(1), hx[1].unsqueeze(1)) |
|
|
|
|
|
self.check_forward_args(input, hx, batch_sizes) |
|
hx = self.permute_hidden(hx, sorted_indices) |
|
|
|
if batch_sizes is None: |
|
result = _VF.lstm( |
|
input, |
|
hx, |
|
self._flat_weights, |
|
self.bias, |
|
self.num_layers, |
|
self.dropout, |
|
self.training, |
|
self.bidirectional, |
|
self.batch_first, |
|
) |
|
else: |
|
result = _VF.lstm( |
|
input, |
|
batch_sizes, |
|
hx, |
|
self._flat_weights, |
|
self.bias, |
|
self.num_layers, |
|
self.dropout, |
|
self.training, |
|
self.bidirectional, |
|
) |
|
output = result[0] |
|
hidden = result[1:] |
|
|
|
if isinstance(orig_input, PackedSequence): |
|
output_packed = PackedSequence( |
|
output, batch_sizes, sorted_indices, unsorted_indices |
|
) |
|
return output_packed, self.permute_hidden(hidden, unsorted_indices) |
|
else: |
|
if not is_batched: |
|
output = output.squeeze(batch_dim) |
|
hidden = (hidden[0].squeeze(1), hidden[1].squeeze(1)) |
|
return output, self.permute_hidden(hidden, unsorted_indices) |
|
|
|
|
|
class GRU(RNNBase): |
|
r"""__init__(input_size,hidden_size,num_layers=1,bias=True,batch_first=False,dropout=0.0,bidirectional=False,device=None,dtype=None) |
|
|
|
Apply a multi-layer gated recurrent unit (GRU) RNN to an input sequence. |
|
For each element in the input sequence, each layer computes the following |
|
function: |
|
|
|
.. math:: |
|
\begin{array}{ll} |
|
r_t = \sigma(W_{ir} x_t + b_{ir} + W_{hr} h_{(t-1)} + b_{hr}) \\ |
|
z_t = \sigma(W_{iz} x_t + b_{iz} + W_{hz} h_{(t-1)} + b_{hz}) \\ |
|
n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) \\ |
|
h_t = (1 - z_t) \odot n_t + z_t \odot h_{(t-1)} |
|
\end{array} |
|
|
|
where :math:`h_t` is the hidden state at time `t`, :math:`x_t` is the input |
|
at time `t`, :math:`h_{(t-1)}` is the hidden state of the layer |
|
at time `t-1` or the initial hidden state at time `0`, and :math:`r_t`, |
|
:math:`z_t`, :math:`n_t` are the reset, update, and new gates, respectively. |
|
:math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. |
|
|
|
In a multilayer GRU, the input :math:`x^{(l)}_t` of the :math:`l` -th layer |
|
(:math:`l \ge 2`) is the hidden state :math:`h^{(l-1)}_t` of the previous layer multiplied by |
|
dropout :math:`\delta^{(l-1)}_t` where each :math:`\delta^{(l-1)}_t` is a Bernoulli random |
|
variable which is :math:`0` with probability :attr:`dropout`. |
|
|
|
Args: |
|
input_size: The number of expected features in the input `x` |
|
hidden_size: The number of features in the hidden state `h` |
|
num_layers: Number of recurrent layers. E.g., setting ``num_layers=2`` |
|
would mean stacking two GRUs together to form a `stacked GRU`, |
|
with the second GRU taking in outputs of the first GRU and |
|
computing the final results. Default: 1 |
|
bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. |
|
Default: ``True`` |
|
batch_first: If ``True``, then the input and output tensors are provided |
|
as `(batch, seq, feature)` instead of `(seq, batch, feature)`. |
|
Note that this does not apply to hidden or cell states. See the |
|
Inputs/Outputs sections below for details. Default: ``False`` |
|
dropout: If non-zero, introduces a `Dropout` layer on the outputs of each |
|
GRU layer except the last layer, with dropout probability equal to |
|
:attr:`dropout`. Default: 0 |
|
bidirectional: If ``True``, becomes a bidirectional GRU. Default: ``False`` |
|
|
|
Inputs: input, h_0 |
|
* **input**: tensor of shape :math:`(L, H_{in})` for unbatched input, |
|
:math:`(L, N, H_{in})` when ``batch_first=False`` or |
|
:math:`(N, L, H_{in})` when ``batch_first=True`` containing the features of |
|
the input sequence. The input can also be a packed variable length sequence. |
|
See :func:`torch.nn.utils.rnn.pack_padded_sequence` or |
|
:func:`torch.nn.utils.rnn.pack_sequence` for details. |
|
* **h_0**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or |
|
:math:`(D * \text{num\_layers}, N, H_{out})` |
|
containing the initial hidden state for the input sequence. Defaults to zeros if not provided. |
|
|
|
where: |
|
|
|
.. math:: |
|
\begin{aligned} |
|
N ={} & \text{batch size} \\ |
|
L ={} & \text{sequence length} \\ |
|
D ={} & 2 \text{ if bidirectional=True otherwise } 1 \\ |
|
H_{in} ={} & \text{input\_size} \\ |
|
H_{out} ={} & \text{hidden\_size} |
|
\end{aligned} |
|
|
|
Outputs: output, h_n |
|
* **output**: tensor of shape :math:`(L, D * H_{out})` for unbatched input, |
|
:math:`(L, N, D * H_{out})` when ``batch_first=False`` or |
|
:math:`(N, L, D * H_{out})` when ``batch_first=True`` containing the output features |
|
`(h_t)` from the last layer of the GRU, for each `t`. If a |
|
:class:`torch.nn.utils.rnn.PackedSequence` has been given as the input, the output |
|
will also be a packed sequence. |
|
* **h_n**: tensor of shape :math:`(D * \text{num\_layers}, H_{out})` or |
|
:math:`(D * \text{num\_layers}, N, H_{out})` containing the final hidden state |
|
for the input sequence. |
|
|
|
Attributes: |
|
weight_ih_l[k] : the learnable input-hidden weights of the :math:`\text{k}^{th}` layer |
|
(W_ir|W_iz|W_in), of shape `(3*hidden_size, input_size)` for `k = 0`. |
|
Otherwise, the shape is `(3*hidden_size, num_directions * hidden_size)` |
|
weight_hh_l[k] : the learnable hidden-hidden weights of the :math:`\text{k}^{th}` layer |
|
(W_hr|W_hz|W_hn), of shape `(3*hidden_size, hidden_size)` |
|
bias_ih_l[k] : the learnable input-hidden bias of the :math:`\text{k}^{th}` layer |
|
(b_ir|b_iz|b_in), of shape `(3*hidden_size)` |
|
bias_hh_l[k] : the learnable hidden-hidden bias of the :math:`\text{k}^{th}` layer |
|
(b_hr|b_hz|b_hn), of shape `(3*hidden_size)` |
|
|
|
.. note:: |
|
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` |
|
where :math:`k = \frac{1}{\text{hidden\_size}}` |
|
|
|
.. note:: |
|
For bidirectional GRUs, forward and backward are directions 0 and 1 respectively. |
|
Example of splitting the output layers when ``batch_first=False``: |
|
``output.view(seq_len, batch, num_directions, hidden_size)``. |
|
|
|
.. note:: |
|
``batch_first`` argument is ignored for unbatched inputs. |
|
|
|
.. note:: |
|
The calculation of new gate :math:`n_t` subtly differs from the original paper and other frameworks. |
|
In the original implementation, the Hadamard product :math:`(\odot)` between :math:`r_t` and the |
|
previous hidden state :math:`h_{(t-1)}` is done before the multiplication with the weight matrix |
|
`W` and addition of bias: |
|
|
|
.. math:: |
|
\begin{aligned} |
|
n_t = \tanh(W_{in} x_t + b_{in} + W_{hn} ( r_t \odot h_{(t-1)} ) + b_{hn}) |
|
\end{aligned} |
|
|
|
    This is in contrast to the PyTorch implementation, where it is done after :math:`W_{hn} h_{(t-1)}`
|
|
|
.. math:: |
|
\begin{aligned} |
|
n_t = \tanh(W_{in} x_t + b_{in} + r_t \odot (W_{hn} h_{(t-1)}+ b_{hn})) |
|
\end{aligned} |
|
|
|
This implementation differs on purpose for efficiency. |
|
|
|
.. include:: ../cudnn_persistent_rnn.rst |
|
|
|
Examples:: |
|
|
|
>>> rnn = nn.GRU(10, 20, 2) |
|
>>> input = torch.randn(5, 3, 10) |
|
>>> h0 = torch.randn(2, 3, 20) |
|
>>> output, hn = rnn(input, h0) |
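
        A minimal sketch of a bidirectional GRU; the output feature dimension becomes
        ``2 * hidden_size`` and ``h_n`` holds one state per layer and direction::

            >>> rnn = nn.GRU(10, 20, 2, bidirectional=True)
            >>> input = torch.randn(5, 3, 10)
            >>> h0 = torch.randn(4, 3, 20)
            >>> output, hn = rnn(input, h0)
            >>> output.shape, hn.shape
            (torch.Size([5, 3, 40]), torch.Size([4, 3, 20]))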
|
""" |
|
|
|
@overload |
|
def __init__( |
|
self, |
|
input_size: int, |
|
hidden_size: int, |
|
num_layers: int = 1, |
|
bias: bool = True, |
|
batch_first: bool = False, |
|
dropout: float = 0.0, |
|
bidirectional: bool = False, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
... |
|
|
|
@overload |
|
def __init__(self, *args, **kwargs): |
|
... |
|
|
|
def __init__(self, *args, **kwargs): |
|
if "proj_size" in kwargs: |
|
raise ValueError( |
|
"proj_size argument is only supported for LSTM, not RNN or GRU" |
|
) |
|
super().__init__("GRU", *args, **kwargs) |
|
|
|
@overload |
|
@torch._jit_internal._overload_method |
|
def forward( |
|
self, input: Tensor, hx: Optional[Tensor] = None |
|
) -> tuple[Tensor, Tensor]: |
|
pass |
|
|
|
@overload |
|
@torch._jit_internal._overload_method |
|
def forward( |
|
self, input: PackedSequence, hx: Optional[Tensor] = None |
|
) -> tuple[PackedSequence, Tensor]: |
|
pass |
|
|
|
def forward(self, input, hx=None): |
|
self._update_flat_weights() |
|
|
|
orig_input = input |
|
|
|
if isinstance(orig_input, PackedSequence): |
|
input, batch_sizes, sorted_indices, unsorted_indices = input |
|
max_batch_size = batch_sizes[0] |
|
if hx is None: |
|
num_directions = 2 if self.bidirectional else 1 |
|
hx = torch.zeros( |
|
self.num_layers * num_directions, |
|
max_batch_size, |
|
self.hidden_size, |
|
dtype=input.dtype, |
|
device=input.device, |
|
) |
|
else: |
|
|
|
|
|
hx = self.permute_hidden(hx, sorted_indices) |
|
else: |
|
batch_sizes = None |
|
if input.dim() not in (2, 3): |
|
raise ValueError( |
|
f"GRU: Expected input to be 2D or 3D, got {input.dim()}D instead" |
|
) |
|
is_batched = input.dim() == 3 |
|
batch_dim = 0 if self.batch_first else 1 |
|
if not is_batched: |
|
input = input.unsqueeze(batch_dim) |
|
if hx is not None: |
|
if hx.dim() != 2: |
|
raise RuntimeError( |
|
f"For unbatched 2-D input, hx should also be 2-D but got {hx.dim()}-D tensor" |
|
) |
|
hx = hx.unsqueeze(1) |
|
else: |
|
if hx is not None and hx.dim() != 3: |
|
raise RuntimeError( |
|
f"For batched 3-D input, hx should also be 3-D but got {hx.dim()}-D tensor" |
|
) |
|
max_batch_size = input.size(0) if self.batch_first else input.size(1) |
|
sorted_indices = None |
|
unsorted_indices = None |
|
if hx is None: |
|
num_directions = 2 if self.bidirectional else 1 |
|
hx = torch.zeros( |
|
self.num_layers * num_directions, |
|
max_batch_size, |
|
self.hidden_size, |
|
dtype=input.dtype, |
|
device=input.device, |
|
) |
|
else: |
|
|
|
|
|
hx = self.permute_hidden(hx, sorted_indices) |
|
|
|
self.check_forward_args(input, hx, batch_sizes) |
|
if batch_sizes is None: |
|
result = _VF.gru( |
|
input, |
|
hx, |
|
self._flat_weights, |
|
self.bias, |
|
self.num_layers, |
|
self.dropout, |
|
self.training, |
|
self.bidirectional, |
|
self.batch_first, |
|
) |
|
else: |
|
result = _VF.gru( |
|
input, |
|
batch_sizes, |
|
hx, |
|
self._flat_weights, |
|
self.bias, |
|
self.num_layers, |
|
self.dropout, |
|
self.training, |
|
self.bidirectional, |
|
) |
|
output = result[0] |
|
hidden = result[1] |
|
|
|
|
|
if isinstance(orig_input, PackedSequence): |
|
output_packed = PackedSequence( |
|
output, batch_sizes, sorted_indices, unsorted_indices |
|
) |
|
return output_packed, self.permute_hidden(hidden, unsorted_indices) |
|
else: |
|
if not is_batched: |
|
output = output.squeeze(batch_dim) |
|
hidden = hidden.squeeze(1) |
|
|
|
return output, self.permute_hidden(hidden, unsorted_indices) |
|
|
|
|
|
class RNNCellBase(Module): |
|
__constants__ = ["input_size", "hidden_size", "bias"] |
|
|
|
input_size: int |
|
hidden_size: int |
|
bias: bool |
|
weight_ih: Tensor |
|
weight_hh: Tensor |
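    # Note: bias_ih and bias_hh are deliberately not annotated here; when
    # bias=False they are registered as None via register_parameter.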
|
|
|
|
|
|
|
def __init__( |
|
self, |
|
input_size: int, |
|
hidden_size: int, |
|
bias: bool, |
|
num_chunks: int, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__() |
|
self.input_size = input_size |
|
self.hidden_size = hidden_size |
|
self.bias = bias |
|
self.weight_ih = Parameter( |
|
torch.empty((num_chunks * hidden_size, input_size), **factory_kwargs) |
|
) |
|
self.weight_hh = Parameter( |
|
torch.empty((num_chunks * hidden_size, hidden_size), **factory_kwargs) |
|
) |
|
if bias: |
|
self.bias_ih = Parameter( |
|
torch.empty(num_chunks * hidden_size, **factory_kwargs) |
|
) |
|
self.bias_hh = Parameter( |
|
torch.empty(num_chunks * hidden_size, **factory_kwargs) |
|
) |
|
else: |
|
self.register_parameter("bias_ih", None) |
|
self.register_parameter("bias_hh", None) |
|
|
|
self.reset_parameters() |
|
|
|
def extra_repr(self) -> str: |
|
s = "{input_size}, {hidden_size}" |
|
if "bias" in self.__dict__ and self.bias is not True: |
|
s += ", bias={bias}" |
|
if "nonlinearity" in self.__dict__ and self.nonlinearity != "tanh": |
|
s += ", nonlinearity={nonlinearity}" |
|
return s.format(**self.__dict__) |
|
|
|
def reset_parameters(self) -> None: |
|
stdv = 1.0 / math.sqrt(self.hidden_size) if self.hidden_size > 0 else 0 |
|
for weight in self.parameters(): |
|
init.uniform_(weight, -stdv, stdv) |
|
|
|
|
|
class RNNCell(RNNCellBase): |
|
r"""An Elman RNN cell with tanh or ReLU non-linearity. |
|
|
|
.. math:: |
|
|
|
h' = \tanh(W_{ih} x + b_{ih} + W_{hh} h + b_{hh}) |
|
|
|
If :attr:`nonlinearity` is `'relu'`, then ReLU is used in place of tanh. |
|
|
|
Args: |
|
input_size: The number of expected features in the input `x` |
|
hidden_size: The number of features in the hidden state `h` |
|
bias: If ``False``, then the layer does not use bias weights `b_ih` and `b_hh`. |
|
Default: ``True`` |
|
nonlinearity: The non-linearity to use. Can be either ``'tanh'`` or ``'relu'``. Default: ``'tanh'`` |
|
|
|
Inputs: input, hidden |
|
- **input**: tensor containing input features |
|
        - **hidden**: tensor containing the initial hidden state.
|
Defaults to zero if not provided. |
|
|
|
Outputs: h' |
|
- **h'** of shape `(batch, hidden_size)`: tensor containing the next hidden state |
|
for each element in the batch |
|
|
|
Shape: |
|
- input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where |
|
:math:`H_{in}` = `input_size`. |
|
- hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden |
|
state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. |
|
- output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. |
|
|
|
Attributes: |
|
weight_ih: the learnable input-hidden weights, of shape |
|
`(hidden_size, input_size)` |
|
weight_hh: the learnable hidden-hidden weights, of shape |
|
`(hidden_size, hidden_size)` |
|
bias_ih: the learnable input-hidden bias, of shape `(hidden_size)` |
|
bias_hh: the learnable hidden-hidden bias, of shape `(hidden_size)` |
|
|
|
.. note:: |
|
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` |
|
where :math:`k = \frac{1}{\text{hidden\_size}}` |
|
|
|
Examples:: |
|
|
|
>>> rnn = nn.RNNCell(10, 20) |
|
>>> input = torch.randn(6, 3, 10) |
|
>>> hx = torch.randn(3, 20) |
|
>>> output = [] |
|
>>> for i in range(6): |
|
... hx = rnn(input[i], hx) |
|
... output.append(hx) |
|
""" |
|
|
|
__constants__ = ["input_size", "hidden_size", "bias", "nonlinearity"] |
|
nonlinearity: str |
|
|
|
def __init__( |
|
self, |
|
input_size: int, |
|
hidden_size: int, |
|
bias: bool = True, |
|
nonlinearity: str = "tanh", |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__(input_size, hidden_size, bias, num_chunks=1, **factory_kwargs) |
|
self.nonlinearity = nonlinearity |
|
|
|
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: |
|
if input.dim() not in (1, 2): |
|
raise ValueError( |
|
f"RNNCell: Expected input to be 1D or 2D, got {input.dim()}D instead" |
|
) |
|
if hx is not None and hx.dim() not in (1, 2): |
|
raise ValueError( |
|
f"RNNCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" |
|
) |
|
is_batched = input.dim() == 2 |
|
if not is_batched: |
|
input = input.unsqueeze(0) |
|
|
|
if hx is None: |
|
hx = torch.zeros( |
|
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device |
|
) |
|
else: |
|
hx = hx.unsqueeze(0) if not is_batched else hx |
|
|
|
if self.nonlinearity == "tanh": |
|
ret = _VF.rnn_tanh_cell( |
|
input, |
|
hx, |
|
self.weight_ih, |
|
self.weight_hh, |
|
self.bias_ih, |
|
self.bias_hh, |
|
) |
|
elif self.nonlinearity == "relu": |
|
ret = _VF.rnn_relu_cell( |
|
input, |
|
hx, |
|
self.weight_ih, |
|
self.weight_hh, |
|
self.bias_ih, |
|
self.bias_hh, |
|
) |
|
else: |
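            # placeholder assignment keeps ret defined on every branch (needed for
            # TorchScript); the raise below makes this value unreachable at runtime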
|
ret = input |
|
raise RuntimeError(f"Unknown nonlinearity: {self.nonlinearity}") |
|
|
|
if not is_batched: |
|
ret = ret.squeeze(0) |
|
|
|
return ret |
|
|
|
|
|
class LSTMCell(RNNCellBase): |
|
r"""A long short-term memory (LSTM) cell. |
|
|
|
.. math:: |
|
|
|
\begin{array}{ll} |
|
i = \sigma(W_{ii} x + b_{ii} + W_{hi} h + b_{hi}) \\ |
|
f = \sigma(W_{if} x + b_{if} + W_{hf} h + b_{hf}) \\ |
|
g = \tanh(W_{ig} x + b_{ig} + W_{hg} h + b_{hg}) \\ |
|
o = \sigma(W_{io} x + b_{io} + W_{ho} h + b_{ho}) \\ |
|
c' = f \odot c + i \odot g \\ |
|
h' = o \odot \tanh(c') \\ |
|
\end{array} |
|
|
|
where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. |
|
|
|
Args: |
|
input_size: The number of expected features in the input `x` |
|
hidden_size: The number of features in the hidden state `h` |
|
bias: If ``False``, then the layer does not use bias weights `b_ih` and |
|
`b_hh`. Default: ``True`` |
|
|
|
Inputs: input, (h_0, c_0) |
|
- **input** of shape `(batch, input_size)` or `(input_size)`: tensor containing input features |
|
- **h_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial hidden state |
|
- **c_0** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the initial cell state |
|
|
|
If `(h_0, c_0)` is not provided, both **h_0** and **c_0** default to zero. |
|
|
|
Outputs: (h_1, c_1) |
|
- **h_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next hidden state |
|
- **c_1** of shape `(batch, hidden_size)` or `(hidden_size)`: tensor containing the next cell state |
|
|
|
Attributes: |
|
weight_ih: the learnable input-hidden weights, of shape |
|
`(4*hidden_size, input_size)` |
|
weight_hh: the learnable hidden-hidden weights, of shape |
|
`(4*hidden_size, hidden_size)` |
|
bias_ih: the learnable input-hidden bias, of shape `(4*hidden_size)` |
|
bias_hh: the learnable hidden-hidden bias, of shape `(4*hidden_size)` |
|
|
|
.. note:: |
|
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` |
|
where :math:`k = \frac{1}{\text{hidden\_size}}` |
|
|
|
On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward. |
|
|
|
Examples:: |
|
|
|
>>> rnn = nn.LSTMCell(10, 20) # (input_size, hidden_size) |
|
>>> input = torch.randn(2, 3, 10) # (time_steps, batch, input_size) |
|
>>> hx = torch.randn(3, 20) # (batch, hidden_size) |
|
>>> cx = torch.randn(3, 20) |
|
>>> output = [] |
|
>>> for i in range(input.size()[0]): |
|
... hx, cx = rnn(input[i], (hx, cx)) |
|
... output.append(hx) |
|
>>> output = torch.stack(output, dim=0) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_size: int, |
|
hidden_size: int, |
|
bias: bool = True, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__(input_size, hidden_size, bias, num_chunks=4, **factory_kwargs) |
|
|
|
def forward( |
|
self, input: Tensor, hx: Optional[tuple[Tensor, Tensor]] = None |
|
) -> tuple[Tensor, Tensor]: |
|
if input.dim() not in (1, 2): |
|
raise ValueError( |
|
f"LSTMCell: Expected input to be 1D or 2D, got {input.dim()}D instead" |
|
) |
|
if hx is not None: |
|
for idx, value in enumerate(hx): |
|
if value.dim() not in (1, 2): |
|
raise ValueError( |
|
f"LSTMCell: Expected hx[{idx}] to be 1D or 2D, got {value.dim()}D instead" |
|
) |
|
is_batched = input.dim() == 2 |
|
if not is_batched: |
|
input = input.unsqueeze(0) |
|
|
|
if hx is None: |
|
zeros = torch.zeros( |
|
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device |
|
) |
|
hx = (zeros, zeros) |
|
else: |
|
hx = (hx[0].unsqueeze(0), hx[1].unsqueeze(0)) if not is_batched else hx |
|
|
|
ret = _VF.lstm_cell( |
|
input, |
|
hx, |
|
self.weight_ih, |
|
self.weight_hh, |
|
self.bias_ih, |
|
self.bias_hh, |
|
) |
|
|
|
if not is_batched: |
|
ret = (ret[0].squeeze(0), ret[1].squeeze(0)) |
|
return ret |
|
|
|
|
|
class GRUCell(RNNCellBase): |
|
r"""A gated recurrent unit (GRU) cell. |
|
|
|
.. math:: |
|
|
|
\begin{array}{ll} |
|
r = \sigma(W_{ir} x + b_{ir} + W_{hr} h + b_{hr}) \\ |
|
z = \sigma(W_{iz} x + b_{iz} + W_{hz} h + b_{hz}) \\ |
|
n = \tanh(W_{in} x + b_{in} + r \odot (W_{hn} h + b_{hn})) \\ |
|
h' = (1 - z) \odot n + z \odot h |
|
\end{array} |
|
|
|
where :math:`\sigma` is the sigmoid function, and :math:`\odot` is the Hadamard product. |
|
|
|
Args: |
|
input_size: The number of expected features in the input `x` |
|
hidden_size: The number of features in the hidden state `h` |
|
bias: If ``False``, then the layer does not use bias weights `b_ih` and |
|
`b_hh`. Default: ``True`` |
|
|
|
Inputs: input, hidden |
|
- **input** : tensor containing input features |
|
- **hidden** : tensor containing the initial hidden |
|
state for each element in the batch. |
|
Defaults to zero if not provided. |
|
|
|
Outputs: h' |
|
- **h'** : tensor containing the next hidden state |
|
for each element in the batch |
|
|
|
Shape: |
|
- input: :math:`(N, H_{in})` or :math:`(H_{in})` tensor containing input features where |
|
:math:`H_{in}` = `input_size`. |
|
- hidden: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the initial hidden |
|
state where :math:`H_{out}` = `hidden_size`. Defaults to zero if not provided. |
|
- output: :math:`(N, H_{out})` or :math:`(H_{out})` tensor containing the next hidden state. |
|
|
|
Attributes: |
|
weight_ih: the learnable input-hidden weights, of shape |
|
`(3*hidden_size, input_size)` |
|
weight_hh: the learnable hidden-hidden weights, of shape |
|
`(3*hidden_size, hidden_size)` |
|
bias_ih: the learnable input-hidden bias, of shape `(3*hidden_size)` |
|
bias_hh: the learnable hidden-hidden bias, of shape `(3*hidden_size)` |
|
|
|
.. note:: |
|
All the weights and biases are initialized from :math:`\mathcal{U}(-\sqrt{k}, \sqrt{k})` |
|
where :math:`k = \frac{1}{\text{hidden\_size}}` |
|
|
|
On certain ROCm devices, when using float16 inputs this module will use :ref:`different precision<fp16_on_mi200>` for backward. |
|
|
|
Examples:: |
|
|
|
>>> rnn = nn.GRUCell(10, 20) |
|
>>> input = torch.randn(6, 3, 10) |
|
>>> hx = torch.randn(3, 20) |
|
>>> output = [] |
|
>>> for i in range(6): |
|
... hx = rnn(input[i], hx) |
|
... output.append(hx) |
|
""" |
|
|
|
def __init__( |
|
self, |
|
input_size: int, |
|
hidden_size: int, |
|
bias: bool = True, |
|
device=None, |
|
dtype=None, |
|
) -> None: |
|
factory_kwargs = {"device": device, "dtype": dtype} |
|
super().__init__(input_size, hidden_size, bias, num_chunks=3, **factory_kwargs) |
|
|
|
def forward(self, input: Tensor, hx: Optional[Tensor] = None) -> Tensor: |
|
if input.dim() not in (1, 2): |
|
raise ValueError( |
|
f"GRUCell: Expected input to be 1D or 2D, got {input.dim()}D instead" |
|
) |
|
if hx is not None and hx.dim() not in (1, 2): |
|
raise ValueError( |
|
f"GRUCell: Expected hidden to be 1D or 2D, got {hx.dim()}D instead" |
|
) |
|
is_batched = input.dim() == 2 |
|
if not is_batched: |
|
input = input.unsqueeze(0) |
|
|
|
if hx is None: |
|
hx = torch.zeros( |
|
input.size(0), self.hidden_size, dtype=input.dtype, device=input.device |
|
) |
|
else: |
|
hx = hx.unsqueeze(0) if not is_batched else hx |
|
|
|
ret = _VF.gru_cell( |
|
input, |
|
hx, |
|
self.weight_ih, |
|
self.weight_hh, |
|
self.bias_ih, |
|
self.bias_hh, |
|
) |
|
|
|
if not is_batched: |
|
ret = ret.squeeze(0) |
|
|
|
return ret |
|
|