|
import fnmatch |
|
import importlib |
|
import inspect |
|
import sys |
|
from dataclasses import dataclass |
|
from enum import Enum |
|
from functools import partial |
|
from inspect import signature |
|
from types import ModuleType |
|
from typing import Any, Callable, Dict, get_args, Iterable, List, Mapping, Optional, Set, Type, TypeVar, Union |
|
|
|
from torch import nn |
|
|
|
from .._internally_replaced_utils import load_state_dict_from_url |
|
|
|
|
|
# Names exported by ``from <this module> import *``.
__all__ = [
    "WeightsEnum",
    "Weights",
    "get_model",
    "get_model_builder",
    "get_model_weights",
    "get_weight",
    "list_models",
]
|
|
|
|
|
@dataclass
class Weights:
    """
    Bundle of the attributes that describe one set of pre-trained weights.

    Args:
        url (str): Location the weights are fetched from.
        transforms (Callable): A constructor for the preprocessing method (or validation preset transforms)
            needed to use the model. A constructor is stored instead of an already-built object because the
            object might hold state, so its creation is delayed until it is actually needed.
        meta (Dict[str, Any]): Meta-data about the weights and the model configuration. This can hold
            informative attributes (e.g. number of parameters/flops, training recipe link/methods),
            configuration parameters required to build the model (e.g. ``num_classes``), or data required
            to use the model (e.g. the ``classes`` of a classification model).
    """

    url: str
    transforms: Callable
    meta: Dict[str, Any]

    # Note: defining ``__eq__`` on a (non-frozen) dataclass leaves instances unhashable.
    def __eq__(self, other: Any) -> bool:
        if not isinstance(other, Weights):
            return NotImplemented
        if self.url != other.url or self.meta != other.meta:
            return False

        # ``transforms`` may be a ``functools.partial``. Two partials wrapping the same
        # callable with the same arguments are distinct objects and would otherwise
        # compare by identity, so their components are compared one by one.
        mine, theirs = self.transforms, other.transforms
        if not (isinstance(mine, partial) and isinstance(theirs, partial)):
            return mine == theirs
        return mine.func == theirs.func and mine.args == theirs.args and mine.keywords == theirs.keywords
|
|
|
|
|
class WeightsEnum(Enum):
    """
    Parent class of all model weights enums. Each model building method receives an optional ``weights``
    parameter holding a member of its associated ``WeightsEnum`` subclass. Member values are expected to
    be of type ``Weights``.

    Args:
        value (Weights): The data class entry with the weight information.
    """

    @classmethod
    def verify(cls, obj: Any) -> Any:
        """
        Normalize ``obj`` into a member of this enum.

        Args:
            obj: ``None``, a member of ``cls``, or a ``str`` naming a member either bare
                (``"IMAGENET1K_V1"``) or qualified with this enum's class name
                (``"<EnumName>.IMAGENET1K_V1"``).

        Returns:
            ``None`` when ``obj`` is ``None``; otherwise the corresponding member of ``cls``.

        Raises:
            KeyError: If ``obj`` is a string that does not name a member of ``cls``.
            TypeError: If ``obj`` is neither ``None``, a ``str``, nor a member of ``cls``.
        """
        if obj is None:
            return obj
        if type(obj) is str:
            prefix = cls.__name__ + "."
            # Strip only a *leading* "<EnumName>." qualifier so that malformed names
            # such as "<EnumName>.<EnumName>.MEMBER" are rejected instead of accepted.
            member_name = obj[len(prefix):] if obj.startswith(prefix) else obj
            return cls[member_name]
        if not isinstance(obj, cls):
            raise TypeError(
                f"Invalid Weight class provided; expected {cls.__name__} but received {obj.__class__.__name__}."
            )
        return obj

    def get_state_dict(self, *args: Any, **kwargs: Any) -> Mapping[str, Any]:
        """Load the state dict from ``self.url`` via ``load_state_dict_from_url``, forwarding all extra arguments."""
        return load_state_dict_from_url(self.url, *args, **kwargs)

    def __repr__(self) -> str:
        # Rendered as "<EnumName>.<MEMBER>", the same form ``verify`` accepts as a string.
        return f"{self.__class__.__name__}.{self._name_}"

    @property
    def url(self) -> str:
        """The ``url`` attribute of this member's value."""
        return self.value.url

    @property
    def transforms(self) -> Callable:
        """The ``transforms`` attribute of this member's value."""
        return self.value.transforms

    @property
    def meta(self) -> Dict[str, Any]:
        """The ``meta`` attribute of this member's value."""
        return self.value.meta
|
|
|
|
|
def get_weight(name: str) -> WeightsEnum:
    """
    Gets the weights enum value by its full name. Example: "ResNet50_Weights.IMAGENET1K_V1"

    The enum class is searched for in the parent package of this module and in that package's
    sub-packages (module members whose ``__file__`` ends with ``__init__.py``).

    Args:
        name (str): The name of the weight enum entry, formatted as "<EnumName>.<MEMBER>".

    Returns:
        WeightsEnum: The requested weight enum.

    Raises:
        ValueError: If ``name`` is not of the form "<EnumName>.<MEMBER>", or no ``WeightsEnum``
            subclass called ``<EnumName>`` is found.
        KeyError: If the enum exists but has no member called ``<MEMBER>``.
    """
    try:
        enum_name, value_name = name.split(".")
    except ValueError:
        raise ValueError(f"Invalid weight name provided: '{name}'.") from None

    base_module_name = __name__.rpartition(".")[0]
    base_module = importlib.import_module(base_module_name)
    model_modules = [base_module] + [
        module
        for _, module in inspect.getmembers(base_module, inspect.ismodule)
        # Modules without a usable ``__file__`` (missing or None) are skipped
        # instead of raising AttributeError.
        if (getattr(module, "__file__", None) or "").endswith("__init__.py")
    ]

    for module in model_modules:
        candidate = module.__dict__.get(enum_name, None)
        # ``enum_name`` may match a non-class attribute (e.g. a function); calling
        # issubclass() on it would raise TypeError, so require a class first.
        if isinstance(candidate, type) and issubclass(candidate, WeightsEnum):
            return candidate[value_name]

    raise ValueError(f"The weight enum '{enum_name}' for the specific method couldn't be retrieved.")
|
|
|
|
|
def get_model_weights(name: Union[Callable, str]) -> Type[WeightsEnum]:
    """
    Look up the weights enum class that belongs to a model.

    Args:
        name (callable or str): Either the model builder itself or the name it was registered under.

    Returns:
        weights_enum (WeightsEnum): The weights enum class read from the builder's ``weights`` annotation.
    """
    if isinstance(name, str):
        builder = get_model_builder(name)
    else:
        builder = name
    return _get_enum_from_fn(builder)
|
|
|
|
|
def _get_enum_from_fn(fn: Callable) -> Type[WeightsEnum]:
    """
    Internal method that gets the weight enum of a specific model builder method.

    The enum is read from the type annotation of the builder's ``weights`` parameter. The annotation
    may be the enum class itself or a typing construct wrapping it (e.g. ``Optional[ResNet50_Weights]``).
    String annotations (e.g. from ``from __future__ import annotations``) are resolved first.

    Args:
        fn (Callable): The builder method used to create the model.

    Returns:
        WeightsEnum: The requested weight enum.

    Raises:
        ValueError: If ``fn`` has no ``weights`` parameter, or its annotation does not reference a
            ``WeightsEnum`` subclass.
    """
    sig = signature(fn)
    if "weights" not in sig.parameters:
        raise ValueError("The method is missing the 'weights' argument.")

    ann = sig.parameters["weights"].annotation
    if isinstance(ann, str):
        # Postponed (string) annotations must be evaluated before they can be inspected.
        # If resolution fails, fall through to the explicit ValueError below.
        from typing import get_type_hints

        try:
            ann = get_type_hints(fn).get("weights", ann)
        except (AttributeError, NameError, SyntaxError, TypeError):
            pass

    # Check the annotation itself first, then the arguments of a wrapping construct (Optional/Union).
    for candidate in (ann, *get_args(ann)):
        if isinstance(candidate, type) and issubclass(candidate, WeightsEnum):
            return candidate

    raise ValueError(
        "The WeightsEnum class for the specific method couldn't be retrieved. Make sure the typing info is correct."
    )
|
|
|
|
|
M = TypeVar("M", bound=nn.Module)

# Registry of model builders, keyed by the name given to ``register_model``
# (or by the builder's ``__name__`` when no name is given).
BUILTIN_MODELS: Dict[str, Callable[..., nn.Module]] = {}


def register_model(name: Optional[str] = None) -> Callable[[Callable[..., M]], Callable[..., M]]:
    """
    Decorator factory that records a model builder in ``BUILTIN_MODELS``.

    Args:
        name (str, optional): Key to register the builder under. Defaults to the builder's ``__name__``.

    Returns:
        A decorator that registers the builder and returns it unchanged.

    Raises:
        ValueError: When the decorator is applied and the key is already registered.
    """

    def decorator(builder: Callable[..., M]) -> Callable[..., M]:
        registry_key = builder.__name__ if name is None else name
        if registry_key in BUILTIN_MODELS:
            raise ValueError(f"An entry is already registered under the name '{registry_key}'.")
        BUILTIN_MODELS[registry_key] = builder
        return builder

    return decorator
|
|
|
|
|
def list_models(
    module: Optional[ModuleType] = None,
    include: Union[Iterable[str], str, None] = None,
    exclude: Union[Iterable[str], str, None] = None,
) -> List[str]:
    """
    Return the sorted names of registered models.

    Args:
        module (ModuleType, optional): Only list models whose builder is defined in a direct child
            module of ``module``. When omitted, every registered model is considered.
        include (str or Iterable[str], optional): Filter(s) selecting which models to keep. Filters are
            passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match Unix
            shell-style wildcards. With several filters, the result is the union of their matches.
        exclude (str or Iterable[str], optional): Filter(s) applied after ``include`` to remove models.
            Filters are passed to `fnmatch <https://docs.python.org/3/library/fnmatch.html>`__ to match
            Unix shell-style wildcards. Any model matched by any of these filters is removed.

    Returns:
        models (list): The sorted names of the selected models.
    """

    def _patterns(spec: Union[Iterable[str], str]) -> Iterable[str]:
        # A lone string is a single pattern, not an iterable of one-character patterns.
        return [spec] if isinstance(spec, str) else spec

    candidates: Set[str] = set()
    for model_name, builder in BUILTIN_MODELS.items():
        if module is not None and builder.__module__.rsplit(".", 1)[0] != module.__name__:
            continue
        candidates.add(model_name)

    if include:
        selected: Set[str] = set()
        for pattern in _patterns(include):
            selected.update(fnmatch.filter(candidates, pattern))
    else:
        selected = candidates

    if exclude:
        dropped: Set[str] = set()
        for pattern in _patterns(exclude):
            dropped.update(fnmatch.filter(candidates, pattern))
        selected = selected - dropped

    return sorted(selected)
|
|
|
|
|
def get_model_builder(name: str) -> Callable[..., nn.Module]:
    """
    Gets the model name and returns the model builder method.

    The registry is first queried with ``name.lower()`` (so e.g. "Foo" finds a builder registered as
    "foo"). If that fails, ``name`` is tried exactly as given, because ``register_model`` stores keys
    without lowercasing them.

    Args:
        name (str): The name under which the model is registered.

    Returns:
        fn (Callable): The model builder method.

    Raises:
        ValueError: If no builder is registered under either spelling.
    """
    # Lowercase first so every lookup that succeeded before still resolves the same way.
    for key in (name.lower(), name):
        if key in BUILTIN_MODELS:
            return BUILTIN_MODELS[key]
    raise ValueError(f"Unknown model {name.lower()}")
|
|
|
|
|
def get_model(name: str, **config: Any) -> nn.Module:
    """
    Build the model registered under ``name``.

    Args:
        name (str): The name under which the model is registered.
        **config (Any): Keyword arguments forwarded unchanged to the model builder method.

    Returns:
        model (nn.Module): The model returned by the builder.
    """
    builder = get_model_builder(name)
    model = builder(**config)
    return model
|
|