|
|
|
import inspect |
|
|
|
import torch |
|
|
|
|
|
def skip_init(module_cls, *args, **kwargs):
    r"""
    Given a module class object and args / kwargs, instantiate the module without initializing parameters / buffers.

    This can be useful if initialization is slow or if custom initialization will
    be performed, making the default initialization unnecessary. There are some caveats to this, due to
    the way this function is implemented:

    1. The module must accept a `device` arg in its constructor that is passed to any parameters
    or buffers created during construction.

    2. The module must not perform any computation on parameters in its constructor except
    initialization (i.e. functions from :mod:`torch.nn.init`).

    If these conditions are satisfied, the module can be instantiated with parameter / buffer values
    uninitialized, as if having been created using :func:`torch.empty`.

    The module is first constructed with ``device="meta"`` and then moved to the requested
    device with :meth:`torch.nn.Module.to_empty`. The target device is taken from the ``device``
    keyword argument; if it is omitted or ``None``, ``"cpu"`` is used.

    Args:
        module_cls: Class object; should be a subclass of :class:`torch.nn.Module`
        args: args to pass to the module's constructor
        kwargs: kwargs to pass to the module's constructor

    Returns:
        Instantiated module with uninitialized parameters / buffers

    Raises:
        RuntimeError: if ``module_cls`` is not a subclass of :class:`torch.nn.Module`, or if its
            constructor does not accept a ``device`` argument.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> import torch
        >>> m = torch.nn.utils.skip_init(torch.nn.Linear, 5, 1)
        >>> m.weight
        Parameter containing:
        tensor([[0.0000e+00, 1.5846e+29, 7.8307e+00, 2.5250e-29, 1.1210e-44]],
               requires_grad=True)
        >>> m2 = torch.nn.utils.skip_init(torch.nn.Linear, in_features=6, out_features=1)
        >>> m2.weight
        Parameter containing:
        tensor([[-1.4677e+24,  4.5915e-41,  1.4013e-45,  0.0000e+00, -1.4677e+24,
                  4.5915e-41]], requires_grad=True)

    """
    # issubclass() raises TypeError for non-class objects (e.g. a module instance);
    # check isclass first so every invalid input gets the same RuntimeError.
    if not (inspect.isclass(module_cls) and issubclass(module_cls, torch.nn.Module)):
        raise RuntimeError(f"Expected a Module; got {module_cls}")
    if "device" not in inspect.signature(module_cls).parameters:
        raise RuntimeError("Module must support a 'device' arg to skip initialization")

    # Treat an explicit device=None like an omitted device. Forwarding None to
    # to_empty() would not move the tensors off the "meta" device.
    final_device = kwargs.pop("device", None)
    if final_device is None:
        final_device = "cpu"

    # Construct on "meta" so no real storage is allocated or initialized, then
    # materialize uninitialized storage on the requested device.
    kwargs["device"] = "meta"
    return module_cls(*args, **kwargs).to_empty(device=final_device)
|
|