|
import functools |
|
import inspect |
|
import warnings |
|
from collections import OrderedDict |
|
from typing import Any, Callable, Dict, Optional, Tuple, TypeVar, Union |
|
|
|
from torch import nn |
|
|
|
from .._utils import sequence_to_str |
|
from ._api import WeightsEnum |
|
|
|
|
|
class IntermediateLayerGetter(nn.ModuleDict):
    """
    Module wrapper that returns intermediate layers from a model

    It has a strong assumption that the modules have been registered
    into the model in the same order as they are used.
    This means that one should **not** reuse the same nn.Module
    twice in the forward if you want this to work.

    Additionally, it is only able to query submodules that are directly
    assigned to the model. So if `model` is passed, `model.feature1` can
    be returned, but not `model.feature1.layer2`.

    Direct children registered after the last requested layer are not kept,
    since their outputs are never needed.

    Args:
        model (nn.Module): model on which we will extract the features
        return_layers (Dict[name, new_name]): a dict containing the names
            of the modules for which the activations will be returned as
            the key of the dict, and the value of the dict is the name
            of the returned activation (which the user can specify).

    Raises:
        ValueError: if a key of ``return_layers`` is not the name of a direct
            child of ``model``.

    Examples::

        >>> m = torchvision.models.resnet18(weights=ResNet18_Weights.DEFAULT)
        >>> # extract layer1 and layer3, giving as names `feat1` and `feat2`
        >>> new_m = torchvision.models._utils.IntermediateLayerGetter(m,
        ...     {'layer1': 'feat1', 'layer3': 'feat2'})
        >>> out = new_m(torch.rand(1, 3, 224, 224))
        >>> print([(k, v.shape) for k, v in out.items()])
        [('feat1', torch.Size([1, 64, 56, 56])),
         ('feat2', torch.Size([1, 256, 14, 14]))]
    """

    # Module version recorded by nn.Module in the state_dict metadata.
    _version = 2
    # Explicit attribute type for ``return_layers``; presumably needed so TorchScript
    # can type the attribute read in ``forward`` -- TODO confirm.
    __annotations__ = {
        "return_layers": Dict[str, str],
    }

    def __init__(self, model: nn.Module, return_layers: Dict[str, str]) -> None:
        if not set(return_layers).issubset([name for name, _ in model.named_children()]):
            raise ValueError("return_layers are not present in model")
        orig_return_layers = return_layers
        # Work on a copy: entries are deleted below as their layers are reached, while the
        # caller's dict is kept intact and stored as ``self.return_layers``.
        return_layers = {str(k): str(v) for k, v in return_layers.items()}
        layers = OrderedDict()
        for name, module in model.named_children():
            layers[name] = module
            if name in return_layers:
                del return_layers[name]
            # Every requested layer has been collected; later children are not needed.
            if not return_layers:
                break

        super().__init__(layers)
        self.return_layers = orig_return_layers

    def forward(self, x):
        """Run the kept children in registration order and collect the requested outputs.

        Returns an OrderedDict mapping each requested output name to the output of the
        corresponding child, in the order the children were executed.
        """
        out = OrderedDict()
        for name, module in self.items():
            x = module(x)
            if name in self.return_layers:
                out_name = self.return_layers[name]
                out[out_name] = x
        return out
|
|
|
|
|
def _make_divisible(v: float, divisor: int, min_value: Optional[int] = None) -> int:
    """
    This function is taken from the original tf repo.
    It rounds a channel number so that it is divisible by ``divisor`` (the
    original MobileNet code uses 8). It can be seen here:
    https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet/mobilenet.py

    The value is rounded to the nearest multiple of ``divisor`` (halves round
    up) and clamped from below at ``min_value``. If that rounding lost more than
    10% of ``v``, one extra ``divisor`` is added.

    Args:
        v: value to round, e.g. a width-multiplied channel count.
        divisor: the result is rounded to a multiple of this.
        min_value: lower bound for the result; defaults to ``divisor``.

    Returns:
        The rounded value. It is a multiple of ``divisor`` whenever ``min_value``
        is one (always true for the default).
    """
    if min_value is None:
        min_value = divisor
    # int(v + divisor / 2) // divisor * divisor == nearest multiple of divisor (ties up).
    new_v = max(min_value, int(v + divisor / 2) // divisor * divisor)
    # Make sure that round down does not go down by more than 10%.
    if new_v < 0.9 * v:
        new_v += divisor
    return new_v
|
|
|
|
|
# Return type of the function decorated by ``kwonly_to_pos_or_kw``.
D = TypeVar("D")


def kwonly_to_pos_or_kw(fn: Callable[..., D]) -> Callable[..., D]:
    """Decorates a function that uses keyword only parameters to also allow them being passed as positionals.

    For example, consider the use case of changing the signature of ``old_fn`` into the one from ``new_fn``:

    .. code::

        def old_fn(foo, bar, baz=None):
            ...

        def new_fn(foo, *, bar, baz=None):
            ...

    Calling ``old_fn("foo", "bar", "baz")`` was valid, but the same call is no longer valid with ``new_fn``. To keep BC
    and at the same time warn the user of the deprecation, this decorator can be used:

    .. code::

        @kwonly_to_pos_or_kw
        def new_fn(foo, *, bar, baz=None):
            ...

        new_fn("foo", "bar", "baz")

    Positional arguments beyond the regular positional slots are mapped, in order, onto the keyword-only
    parameters, and a deprecation warning naming them is emitted.

    Args:
        fn: function with at least one keyword-only parameter.

    Returns:
        The wrapped function.

    Raises:
        TypeError: at decoration time if ``fn`` has no keyword-only parameter. At call time if more positional
            arguments are passed than there are positional plus keyword-only parameters, or if a keyword-only
            parameter is passed both positionally and by keyword.
    """
    params = inspect.signature(fn).parameters

    try:
        # Number of parameters preceding the first keyword-only one, i.e. how many positionals pass through as-is.
        keyword_only_start_idx = next(
            idx for idx, param in enumerate(params.values()) if param.kind == param.KEYWORD_ONLY
        )
    except StopIteration:
        raise TypeError(f"Found no keyword-only parameter on function '{fn.__name__}'") from None

    # Only genuine keyword-only parameters may absorb overflow positionals; a trailing ``**kwargs`` must not.
    keyword_only_params = tuple(name for name, param in params.items() if param.kind == param.KEYWORD_ONLY)
    max_positional = keyword_only_start_idx + len(keyword_only_params)

    @functools.wraps(fn)
    def wrapper(*args: Any, **kwargs: Any) -> D:
        # Reject surplus positionals instead of silently dropping them via ``zip``.
        if len(args) > max_positional:
            raise TypeError(
                f"{fn.__name__}() takes at most {max_positional} positional argument(s) but {len(args)} were given"
            )
        args, keyword_only_args = args[:keyword_only_start_idx], args[keyword_only_start_idx:]
        if keyword_only_args:
            keyword_only_kwargs = dict(zip(keyword_only_params, keyword_only_args))
            # Mirror Python's own error rather than letting the positional value silently win.
            duplicates = [name for name in keyword_only_kwargs if name in kwargs]
            if duplicates:
                raise TypeError(f"{fn.__name__}() got multiple values for argument '{duplicates[0]}'")
            warnings.warn(
                f"Using {sequence_to_str(tuple(keyword_only_kwargs.keys()), separate_last='and ')} as positional "
                f"parameter(s) is deprecated since 0.13 and may be removed in the future. Please use keyword parameter(s) "
                f"instead.",
                # Attribute the warning to the caller of the decorated function, not to this module.
                stacklevel=2,
            )
            kwargs.update(keyword_only_kwargs)

        return fn(*args, **kwargs)

    return wrapper
|
|
|
|
|
# Weights-enum type used in the ``handle_legacy_interface`` signature.
W = TypeVar("W", bound=WeightsEnum)
# Model type returned by builders wrapped in ``handle_legacy_interface``.
M = TypeVar("M", bound=nn.Module)
# Value type shared by the new/expected and actual arguments of the ``_ovewrite_*_param`` helpers.
V = TypeVar("V")
|
|
|
|
|
def handle_legacy_interface(**weights: Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]):
    """Decorates a model builder with the new interface to make it compatible with the old.

    In particular this handles two things:

    1. Allows positional parameters again, but emits a deprecation warning in case they are used. See
        :func:`torchvision.models._utils.kwonly_to_pos_or_kw` for details.
    2. Handles the default value change from ``pretrained=False`` to ``weights=None`` and ``pretrained=True`` to
        ``weights=Weights`` and emits a deprecation warning with instructions for the new interface.

    Args:
        **weights (Tuple[str, Union[Optional[W], Callable[[Dict[str, Any]], Optional[W]]]]): Each keyword is the name
            of a new weights parameter of the builder. Its value is the deprecated parameter
            name and default value for the legacy ``pretrained=True``. The default value can be a callable in which
            case it will be called with a dictionary of the keyword arguments. The only key that is guaranteed to be in
            the dictionary is the deprecated parameter name passed as first element in the tuple. All other parameters
            should be accessed with :meth:`~dict.get`.

    Returns:
        A decorator that wraps a model builder with the legacy-compatible interface described above.

    Raises:
        ValueError: (when the wrapped builder is called) if a truthy legacy value is passed but the resolved default
            weights are not a ``WeightsEnum``.
    """

    def outer_wrapper(builder: Callable[..., M]) -> Callable[..., M]:
        # ``functools.wraps`` sets ``__wrapped__``, which ``inspect.signature`` follows, so
        # ``kwonly_to_pos_or_kw`` inspects the builder's own signature rather than ``(*args, **kwargs)``.
        @kwonly_to_pos_or_kw
        @functools.wraps(builder)
        def inner_wrapper(*args: Any, **kwargs: Any) -> M:
            for weights_param, (pretrained_param, default) in weights.items():
                # The sentinel distinguishes "weights_param not passed" from any passed value, including None.
                sentinel = object()
                weights_arg = kwargs.get(weights_param, sentinel)
                # Nothing to translate if neither parameter was passed, or if the weights parameter already holds a
                # new-style value: a WeightsEnum, a string other than the special "legacy" marker, or None.
                if (
                    (weights_param not in kwargs and pretrained_param not in kwargs)
                    or isinstance(weights_arg, WeightsEnum)
                    or (isinstance(weights_arg, str) and weights_arg != "legacy")
                    or weights_arg is None
                ):
                    continue

                # From here on, either the legacy parameter was passed by keyword (weights_arg is the sentinel), or
                # the weights parameter holds a legacy value (e.g. a bool or "legacy"). The latter presumably comes
                # from a positional ``pretrained`` that kwonly_to_pos_or_kw mapped onto the weights parameter, which
                # assumes the weights parameter is the builder's first keyword-only one -- TODO confirm.
                pretrained_positional = weights_arg is not sentinel
                if pretrained_positional:
                    # Move the legacy value into the legacy slot so both branches read it from kwargs[pretrained_param].
                    # NOTE(review): if the legacy parameter was also passed by keyword, that value is overwritten here.
                    kwargs[pretrained_param] = pretrained_arg = kwargs.pop(weights_param)
                else:
                    pretrained_arg = kwargs[pretrained_param]

                if pretrained_arg:
                    # A callable default receives the current kwargs, which include ``pretrained_param``.
                    default_weights_arg = default(kwargs) if callable(default) else default
                    if not isinstance(default_weights_arg, WeightsEnum):
                        raise ValueError(f"No weights available for model {builder.__name__}")
                else:
                    default_weights_arg = None

                # Warn about the legacy parameter name only when the caller actually used it by keyword.
                if not pretrained_positional:
                    warnings.warn(
                        f"The parameter '{pretrained_param}' is deprecated since 0.13 and may be removed in the future, "
                        f"please use '{weights_param}' instead."
                    )

                # Always warn, stating the weights value the call is now equivalent to.
                msg = (
                    f"Arguments other than a weight enum or `None` for '{weights_param}' are deprecated since 0.13 and "
                    f"may be removed in the future. "
                    f"The current behavior is equivalent to passing `{weights_param}={default_weights_arg}`."
                )
                if pretrained_arg:
                    msg = (
                        f"{msg} You can also use `{weights_param}={type(default_weights_arg).__name__}.DEFAULT` "
                        f"to get the most up-to-date weights."
                    )
                warnings.warn(msg)

                # Hand the builder the resolved weights instead of the legacy parameter.
                del kwargs[pretrained_param]
                kwargs[weights_param] = default_weights_arg

            return builder(*args, **kwargs)

        return inner_wrapper

    return outer_wrapper
|
|
|
|
|
def _ovewrite_named_param(kwargs: Dict[str, Any], param: str, new_value: V) -> None:
    """Set ``kwargs[param]`` to ``new_value``, refusing to silently replace a conflicting value.

    Mutates ``kwargs`` in place when ``param`` is absent. When ``param`` is already present,
    the dict is left untouched.

    Raises:
        ValueError: if ``param`` is already present with a value that compares unequal to ``new_value``.
    """
    if param not in kwargs:
        kwargs[param] = new_value
        return

    current = kwargs[param]
    if current != new_value:
        raise ValueError(f"The parameter '{param}' expected value {new_value} but got {current} instead.")
|
|
|
|
|
def _ovewrite_value_param(param: str, actual: Optional[V], expected: V) -> V:
    """Check that ``actual`` agrees with ``expected`` and return ``expected``.

    ``actual`` being ``None`` means there is nothing to check; ``expected`` is returned as-is.

    Raises:
        ValueError: if ``actual`` is not ``None`` and compares unequal to ``expected``.
    """
    mismatch = actual is not None and actual != expected
    if mismatch:
        raise ValueError(f"The parameter '{param}' expected value {expected} but got {actual} instead.")
    return expected
|
|
|
|
|
class _ModelURLs(dict):
    """Dictionary of model URLs whose ``[]`` access emits a deprecation warning.

    The stored value is returned unchanged (and a missing key still raises ``KeyError``, after warning).
    NOTE: only ``__getitem__`` is overridden, so inherited accessors such as ``dict.get`` do not warn.
    """

    def __getitem__(self, item):
        warnings.warn(
            "Accessing the model URLs via the internal dictionary of the module is deprecated since 0.13 and may "
            "be removed in the future. Please access them via the appropriate Weights Enum instead.",
            # Point the warning at the code doing the lookup, not at this module.
            stacklevel=2,
        )
        return super().__getitem__(item)
|
|