import functools
import inspect
import os
from collections import OrderedDict
from collections.abc import Generator, Iterable, Sized
from contextlib import contextmanager
from functools import partial
from typing import Any, Callable, Optional, Union

from lightning_utilities.core.inheritance import get_all_subclasses
from torch.utils.data import BatchSampler, DataLoader, IterableDataset, Sampler
from typing_extensions import TypeGuard

from lightning_fabric.utilities.enums import LightningEnum
from lightning_fabric.utilities.exceptions import MisconfigurationException
from lightning_fabric.utilities.rank_zero import rank_zero_warn
from lightning_fabric.utilities.seed import pl_worker_init_function
|
|
|
|
|
class _WrapAttrTag(LightningEnum):
    SET = "set"
    DEL = "del"

    def __call__(self, *args: Any) -> None:
        fn: Union[Callable[[object, str], None], Callable[[object, str, Any], None]]
        fn = setattr if self == self.SET else delattr
        return fn(*args)
|
|
|
|
|
def has_iterable_dataset(dataloader: object) -> bool:
    return hasattr(dataloader, "dataset") and isinstance(dataloader.dataset, IterableDataset)
|
|
|
|
|
def sized_len(dataloader: object) -> Optional[int]:
    """Try to get the length of an object, return ``None`` otherwise."""
    try:
        # try getting the length
        length = len(dataloader)
    except (TypeError, NotImplementedError):
        length = None
    return length
|
|
|
|
|
def has_len(dataloader: object) -> TypeGuard[Sized]:
    """Checks if a given object has ``__len__`` method implemented."""
    length = sized_len(dataloader)
    if length == 0:
        rank_zero_warn(
            f"`{dataloader.__class__.__name__}` returned 0 length. Please make sure this was your intention."
        )
    if length is not None and has_iterable_dataset(dataloader):
        rank_zero_warn(
            "Your `IterableDataset` has `__len__` defined."
            " In combination with multi-process data loading (when num_workers > 1),"
            " `__len__` could be inaccurate if each worker is not configured independently"
            " to avoid having duplicate data."
        )
    return length is not None
|
|
|
|
|
def _update_dataloader(dataloader: DataLoader, sampler: Union[Sampler, Iterable]) -> DataLoader:
    dl_args, dl_kwargs = _get_dataloader_init_args_and_kwargs(dataloader, sampler)
    return _reinstantiate_wrapped_cls(dataloader, *dl_args, **dl_kwargs)
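# Illustrative call (placeholder names): re-create a user's dataloader with a distributed sampler, e.g.
#
#     from torch.utils.data.distributed import DistributedSampler
#     new_loader = _update_dataloader(user_loader, DistributedSampler(user_loader.dataset))
#
# This assumes the re-instantiation information is recoverable, either because the dataloader was created
# while `_replace_dunder_methods` (below) was active or because its `__init__` arguments are stored as
# instance attributes.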
|
|
|
|
|
def _get_dataloader_init_args_and_kwargs(
    dataloader: DataLoader,
    sampler: Union[Sampler, Iterable],
) -> tuple[tuple[Any, ...], dict[str, Any]]:
    if not isinstance(dataloader, DataLoader):
        raise ValueError(f"The dataloader {dataloader} needs to subclass `torch.utils.data.DataLoader`")

    was_wrapped = hasattr(dataloader, "__pl_saved_args")
    if was_wrapped:
        dl_args = dataloader.__pl_saved_args
        dl_kwargs = dataloader.__pl_saved_kwargs
        arg_names = dataloader.__pl_saved_arg_names
        original_dataset = dataloader.__dataset  # saved by the wrapped `__init__` (see `_wrap_init_method`)
    else:
        # fall back to the public instance attributes as the best available reconstruction of the init arguments
        attrs = {k: v for k, v in vars(dataloader).items() if not k.startswith("_")}

        # the dataset is only recoverable from the instance attributes in this case
        original_dataset = None

        # `multiprocessing_context` is not included in `vars(dataloader)`
        attrs["multiprocessing_context"] = dataloader.multiprocessing_context
        arg_names = ()

    # collect the `__init__` parameters of the dataloader (sub)class
    params = dict(inspect.signature(dataloader.__init__).parameters)
    has_variadic_kwargs = any(p.kind is p.VAR_KEYWORD for p in params.values())
    if has_variadic_kwargs:
        # if the signature accepts `**kwargs`, assume unknown attributes can be passed through under the same names
        if was_wrapped:
            # the arguments were captured at construction time; only add the base `DataLoader` parameters
            # that have defaults
            params.update({
                k: v for k, v in inspect.signature(DataLoader.__init__).parameters.items() if v.default is not v.empty
            })
        else:
            params.update(inspect.signature(DataLoader.__init__).parameters)
            params.pop("self", None)

    if not was_wrapped:
        # keep only the attributes whose value differs from the signature default
        non_defaults = {name for name, p in params.items() if name in attrs and p.default != attrs[name]}

        # always include the dataset, as it may have been passed positionally
        non_defaults.add("dataset")

        # keyword arguments used to re-instantiate the dataloader
        dl_kwargs = {k: v for k, v in attrs.items() if k in non_defaults}
        dl_args = ()

    dataset = dl_kwargs.get("dataset", original_dataset)
    if isinstance(dataset, IterableDataset):
        dl_kwargs["batch_sampler"] = None
        dl_kwargs["sampler"] = None
    else:
        dl_kwargs.update(_dataloader_init_kwargs_resolve_sampler(dataloader, sampler))

    # required positional arguments that cannot be recovered from the instance attributes
    required_args = {
        p.name
        for p in params.values()
        if p.kind in (p.POSITIONAL_ONLY, p.POSITIONAL_OR_KEYWORD)
        and p.default is p.empty
        and p.name not in dl_kwargs
        and p.name not in arg_names
    }
    if required_args:
        sorted_required_args = sorted(required_args)
        dataloader_cls_name = dataloader.__class__.__name__
        missing_args_message = ", ".join(f"`self.{arg_name}`" for arg_name in sorted_required_args)
        raise MisconfigurationException(
            f"Trying to inject custom `Sampler` into the `{dataloader_cls_name}` instance. "
            "This would fail as some of the `__init__` arguments are not available as instance attributes. "
            f"The missing attributes are {sorted_required_args}. If you instantiate your `{dataloader_cls_name}` "
            "inside a `*_dataloader` hook of your module, we will do this for you."
            f" Otherwise, define {missing_args_message} inside your `__init__`."
        )

    if not has_variadic_kwargs:
        # without `**kwargs`, every recovered argument must appear in the signature
        missing_kwargs = (set(dl_kwargs) | set(arg_names)) - params.keys()
        if missing_kwargs:
            sorted_missing_kwargs = sorted(missing_kwargs)
            dataloader_cls_name = dataloader.__class__.__name__
            raise TypeError(
                f"Trying to inject parameters into the `{dataloader_cls_name}` instance. "
                "This would fail as it doesn't expose all its attributes in the `__init__` signature. "
                f"The missing arguments are {sorted_missing_kwargs}. HINT: If you wrote the `{dataloader_cls_name}` "
                "class, add the `__init__` arguments or allow passing `**kwargs`"
            )

    return dl_args, dl_kwargs
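# Illustrative sketch: for a plain `DataLoader(dataset, batch_size=32)` created outside the wrapping context,
# this returns roughly `((), {"dataset": dataset, "batch_size": 32, "sampler": <new sampler>, "shuffle": False,
# "batch_sampler": None})`, which `_reinstantiate_wrapped_cls` then passes back to the dataloader class.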
|
|
|
|
|
def _dataloader_init_kwargs_resolve_sampler(
    dataloader: DataLoader,
    sampler: Union[Sampler, Iterable],
) -> dict[str, Any]:
    """This function is used to handle the sampler and batch_sampler arguments associated with a DataLoader for
    its re-instantiation."""
    batch_sampler = getattr(dataloader, "batch_sampler")

    if batch_sampler is not None and type(batch_sampler) is not BatchSampler:
        batch_sampler_cls = type(batch_sampler)
        if hasattr(batch_sampler, "__pl_saved_args"):
            # the batch sampler was created while `_replace_dunder_methods` was active: reuse its captured arguments
            args = batch_sampler.__pl_saved_args
            kwargs = batch_sampler.__pl_saved_kwargs
            default_kwargs = batch_sampler.__pl_saved_default_kwargs
            arg_names = batch_sampler.__pl_saved_arg_names

            success, args, kwargs = _replace_value_in_saved_args(
                "sampler", sampler, args, kwargs, default_kwargs, arg_names
            )
            if not success:
                raise TypeError(
                    "Trying to inject a modified sampler into the batch sampler; however, it seems the class "
                    f"`{batch_sampler_cls.__qualname__}` does not have an argument called `sampler`. To mitigate "
                    "this, expose an argument `sampler` in the `__init__` method of your custom class."
                )

            batch_sampler = _reinstantiate_wrapped_cls(batch_sampler, *args, **kwargs)
        elif hasattr(batch_sampler, "batch_size") and hasattr(batch_sampler, "drop_last"):
            # the arguments were not captured, but the object looks like a `BatchSampler`:
            # try to re-create it through the standard `BatchSampler` API
            try:
                batch_sampler = batch_sampler_cls(
                    sampler,
                    batch_size=batch_sampler.batch_size,
                    drop_last=batch_sampler.drop_last,
                )
            except TypeError as ex:
                import re

                match = re.match(r".*__init__\(\) ((got multiple values)|(missing \d required))", str(ex))
                if not match:
                    # an unexpected `TypeError`: re-raise it unchanged
                    raise

                # the signature rejects the standard `BatchSampler` arguments (too few or too many):
                # raise a descriptive error instead
                raise TypeError(
                    "Lightning can't inject a (distributed) sampler into your batch sampler, because it doesn't"
                    " subclass PyTorch's `BatchSampler`. To mitigate this, either follow the API of `BatchSampler`"
                    " or set `.setup_dataloaders(..., use_distributed_sampler=False)`. If you choose the latter, you"
                    " will be responsible for handling the distributed sampling within your batch sampler."
                ) from ex
        else:
            # the object exposes neither captured arguments nor the `BatchSampler` attributes
            raise TypeError(
                "Lightning can't inject a (distributed) sampler into your batch sampler, because it doesn't"
                " subclass PyTorch's `BatchSampler`. To mitigate this, either follow the API of `BatchSampler`"
                " or set `.setup_dataloaders(..., use_distributed_sampler=False)`. If you choose the latter, you"
                " will be responsible for handling the distributed sampling within your batch sampler."
            )

        return {
            "sampler": None,
            "shuffle": False,
            "batch_sampler": batch_sampler,
            "batch_size": 1,
            "drop_last": False,
        }

    return {"sampler": sampler, "shuffle": False, "batch_sampler": None}
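# Illustrative note: a default `DataLoader(dataset, batch_size=32)` builds a plain `BatchSampler`, so the
# early `if` above is skipped and the second form is returned; the new sampler is injected via `sampler=...`
# and the re-created dataloader builds a fresh `BatchSampler` around it.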
|
|
|
|
|
def _auto_add_worker_init_fn(dataloader: object, rank: int) -> None:
    if not hasattr(dataloader, "worker_init_fn"):
        return
    if int(os.environ.get("PL_SEED_WORKERS", 0)) and dataloader.worker_init_fn is None:
        dataloader.worker_init_fn = partial(pl_worker_init_function, rank=rank)
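# Note: `PL_SEED_WORKERS` is set by `seed_everything(..., workers=True)`, so the worker init function is only
# attached automatically when the user opted into seeded dataloader workers and did not already provide their
# own `worker_init_fn`.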
|
|
|
|
|
def _reinstantiate_wrapped_cls(orig_object: Any, *args: Any, explicit_cls: Optional[type] = None, **kwargs: Any) -> Any:
    constructor = type(orig_object) if explicit_cls is None else explicit_cls

    try:
        result = constructor(*args, **kwargs)
    except TypeError as ex:
        # improve the error message for a known subclassing mistake, where several subclass `__init__`
        # arguments end up mapping to the same parent `__init__` argument
        import re

        match = re.match(r".*__init__\(\) got multiple values .* '(\w+)'", str(ex))
        if not match:
            # an unexpected `TypeError`: re-raise it unchanged
            raise
        argument = match.groups()[0]
        message = (
            f"The {constructor.__name__} implementation has an error where more than one `__init__` argument"
            f" can be passed to its parent's `{argument}=...` `__init__` argument. This is likely caused by allowing"
            f" passing both a custom argument that will map to the `{argument}` argument as well as `**kwargs`."
            f" `kwargs` should be filtered to make sure they don't contain the `{argument}` key."
            " This argument was automatically passed to your object by PyTorch Lightning."
        )
        raise MisconfigurationException(message) from ex

    # replay the attribute changes (set/delete) that were recorded outside `__init__` on the new object
    attrs_record = getattr(orig_object, "__pl_attrs_record", [])
    for args, fn in attrs_record:
        fn(result, *args)

    return result
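# Illustrative note: `__pl_attrs_record` is filled by the patched `__setattr__`/`__delattr__` (see
# `_wrap_attr_method` below), so attribute changes made after construction, e.g. `loader.collate_fn = my_fn`,
# are replayed on the newly created object.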
|
|
|
|
|
def _wrap_init_method(init: Callable, store_explicit_arg: Optional[str] = None) -> Callable:
    """Wraps the ``__init__`` method of classes (currently :class:`~torch.utils.data.DataLoader` and
    :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

    @functools.wraps(init)
    def wrapper(obj: Any, *args: Any, **kwargs: Any) -> None:
        # flag that we are inside `__init__` so the patched `__setattr__`/`__delattr__` do not record
        # attribute changes made during construction
        old_inside_init = getattr(obj, "__pl_inside_init", False)
        object.__setattr__(obj, "__pl_inside_init", True)
        # inspect `init` rather than `obj.__init__`: with multiple inheritance the latter may resolve to a
        # different function than the one being wrapped here
        params = inspect.signature(init).parameters

        parameters_defaults = OrderedDict(
            (param.name, param.default)
            for param in params.values()
            if param.name != "self" and param.kind not in (param.VAR_POSITIONAL, param.VAR_KEYWORD)
        )

        param_names = tuple(parameters_defaults)[: len(args)]

        default_kwargs = {
            name: value
            for name, value in parameters_defaults.items()
            if name not in kwargs and name not in param_names and value != inspect.Parameter.empty
        }

        # only the outermost `__init__` call (i.e. the most derived class) saves the arguments
        if not hasattr(obj, "__pl_saved_args"):
            object.__setattr__(obj, "__pl_saved_args", args)
            object.__setattr__(obj, "__pl_saved_kwargs", kwargs)
            object.__setattr__(obj, "__pl_saved_arg_names", param_names)
            object.__setattr__(obj, "__pl_saved_default_kwargs", default_kwargs)

        # store the explicitly requested argument (e.g. the dataset) under a private name; this runs on every
        # `__init__` in the chain so the value that ultimately reaches the base class wins
        if store_explicit_arg is not None:
            if store_explicit_arg in param_names:
                object.__setattr__(obj, f"__{store_explicit_arg}", args[param_names.index(store_explicit_arg)])
            elif store_explicit_arg in kwargs:
                object.__setattr__(obj, f"__{store_explicit_arg}", kwargs[store_explicit_arg])

        init(obj, *args, **kwargs)
        object.__setattr__(obj, "__pl_inside_init", old_inside_init)

    return wrapper
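# Illustrative sketch (placeholder names): constructing `DataLoader(dataset, batch_size=8)` under the patched
# `__init__` leaves the following on the instance: `__pl_saved_args == (dataset,)`,
# `__pl_saved_arg_names == ("dataset",)`, `__pl_saved_kwargs == {"batch_size": 8}`, and
# `__pl_saved_default_kwargs` holding the signature defaults that were not overridden.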
|
|
|
|
|
def _wrap_attr_method(method: Callable, tag: _WrapAttrTag) -> Callable:
    """Wraps the ``__setattr__`` or ``__delattr__`` method of classes (currently :class:`~torch.utils.data.DataLoader`
    and :class:`~torch.utils.data.BatchSampler`) in order to enable re-instantiation of custom subclasses."""

    @functools.wraps(method)
    def wrapper(obj: Any, *args: Any) -> None:
        # determine whether this is the outermost patched call in the inheritance chain for this attribute
        name, *_ = args
        prev_call_name, prev_call_method = getattr(obj, "__pl_current_call", (None, "method"))
        first_call = not (prev_call_name == name and prev_call_method == tag)

        # mark the attribute and operation currently being processed
        object.__setattr__(obj, "__pl_current_call", (name, tag))

        # call the original method
        method(obj, *args)
        if first_call and not getattr(obj, "__pl_inside_init", True):
            # record the change so it can be replayed after re-instantiation, but only outside `__init__`
            # and only once per call chain
            attrs_record = getattr(obj, "__pl_attrs_record", [])
            attrs_record.append((args, tag))
            object.__setattr__(obj, "__pl_attrs_record", attrs_record)
        object.__setattr__(obj, "__pl_current_call", (prev_call_name, prev_call_method))

    return wrapper
|
|
|
|
|
@contextmanager
def _replace_dunder_methods(base_cls: type, store_explicit_arg: Optional[str] = None) -> Generator[None, None, None]:
    """This context manager is used to add support for re-instantiation of custom subclasses of `base_cls` (and of
    `base_cls` itself).

    It patches the ``__init__``, ``__setattr__`` and ``__delattr__`` methods.

    """
    classes = get_all_subclasses(base_cls) | {base_cls}
    for cls in classes:
        # patch `__init__` only where the class itself defines it, so each implementation is wrapped exactly once
        if "__init__" in cls.__dict__:
            cls.__old__init__ = cls.__init__
            cls.__init__ = _wrap_init_method(cls.__init__, store_explicit_arg)

        # at least one `__setattr__`/`__delattr__` in the chain must be patched; subclasses may not define them,
        # so the base class is always patched
        for patch_fn_name, tag in (("__setattr__", _WrapAttrTag.SET), ("__delattr__", _WrapAttrTag.DEL)):
            if patch_fn_name in cls.__dict__ or cls is base_cls:
                saved_name = f"__old{patch_fn_name}"
                setattr(cls, saved_name, getattr(cls, patch_fn_name))
                setattr(cls, patch_fn_name, _wrap_attr_method(getattr(cls, patch_fn_name), tag))
    yield
    for cls in classes:
        for patched_name in ("__setattr__", "__delattr__", "__init__"):
            # undo the patching only on classes that were actually patched above
            if f"__old{patched_name}" in cls.__dict__:
                setattr(cls, patched_name, getattr(cls, f"__old{patched_name}"))
                delattr(cls, f"__old{patched_name}")
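# Illustrative sketch of the capture step (placeholder names): the context manager must be active while the
# user's dataloader is constructed so that its constructor arguments are recorded for later re-instantiation:
#
#     with _replace_dunder_methods(DataLoader, "dataset"):
#         loader = CustomDataLoader(dataset, batch_size=8)   # arguments are saved on the instance
#     loader = _update_dataloader(loader, new_sampler)        # re-created from the saved arguments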
|
|
|
|
|
def _replace_value_in_saved_args(
    replace_key: str,
    replace_value: Any,
    args: tuple[Any, ...],
    kwargs: dict[str, Any],
    default_kwargs: dict[str, Any],
    arg_names: tuple[str, ...],
) -> tuple[bool, tuple[Any, ...], dict[str, Any]]:
    """Tries to replace an argument value in a saved list of args and kwargs.

    Returns a tuple indicating success of the operation and the modified saved args and kwargs.

    """
    if replace_key in arg_names:
        replace_index = arg_names.index(replace_key)
        args = args[:replace_index] + (replace_value,) + args[replace_index + 1 :]
        return True, args, kwargs
    if replace_key in kwargs or replace_key in default_kwargs:
        kwargs[replace_key] = replace_value
        return True, args, kwargs

    return False, args, kwargs
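# Illustrative example: with `args=(old_sampler,)`, `arg_names=("sampler",)` and empty kwargs,
# `_replace_value_in_saved_args("sampler", new_sampler, args, {}, {}, arg_names)` returns
# `(True, (new_sampler,), {})`.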
|
|
|
|
|
def _set_sampler_epoch(dataloader: object, epoch: int) -> None:
    """Calls the ``set_epoch`` method on either the sampler or the batch sampler of the given dataloader.

    Every PyTorch dataloader has either a sampler or a batch sampler. If the sampler is a
    :class:`~torch.utils.data.distributed.DistributedSampler`, ``set_epoch`` must be called at the beginning
    of every epoch to ensure shuffling applies a new ordering. This has no effect if shuffling is off.

    """
    # collect the sampler and, if present, the sampler nested inside the batch sampler; keying by `id`
    # avoids calling `set_epoch` twice on the same object
    objects: dict[int, Any] = {}
    if (sampler := getattr(dataloader, "sampler", None)) is not None:
        objects[id(sampler)] = sampler
    if (batch_sampler := getattr(dataloader, "batch_sampler", None)) is not None and (
        sampler := getattr(batch_sampler, "sampler", None)
    ) is not None:
        objects[id(sampler)] = sampler
    for obj in objects.values():
        set_epoch = getattr(obj, "set_epoch", None)
        if callable(set_epoch):
            set_epoch(epoch)
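# Illustrative usage (placeholder names): call once at the start of every epoch of a training loop, e.g.
#
#     for epoch in range(num_epochs):
#         _set_sampler_epoch(train_dataloader, epoch)
#         for batch in train_dataloader:
#             ...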
|
|
|
|
|
def suggested_max_num_workers(local_world_size: int) -> int:
    """Suggests an upper bound of ``num_workers`` to use in a PyTorch :class:`~torch.utils.data.DataLoader` based on
    the number of CPU cores available on the system and the number of distributed processes in the current machine.

    Args:
        local_world_size: The number of distributed processes running on the current machine. Set this to the number
            of devices configured in Fabric/Trainer.

    """
    if local_world_size < 1:
        raise ValueError(f"`local_world_size` should be >= 1, got {local_world_size}.")
    cpu_count = _num_cpus_available()
    return max(1, cpu_count // local_world_size - 1)
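# Worked example: on a machine with 32 usable cores and 8 local processes, the suggestion is
# max(1, 32 // 8 - 1) == 3 workers per dataloader; the `- 1` leaves a core for the process itself.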
|
|
|
|
|
def _num_cpus_available() -> int:
    if hasattr(os, "sched_getaffinity"):
        # respects CPU affinity restrictions (e.g. set by job schedulers), unlike `os.cpu_count()`
        return len(os.sched_getaffinity(0))

    cpu_count = os.cpu_count()
    return 1 if cpu_count is None else cpu_count
|
|
|
|
|
class AttributeDict(dict):
    """A container to store state variables of your program.

    This is a drop-in replacement for a Python dictionary, with the additional functionality to access and modify keys
    through attribute lookup for convenience.

    Use this to define the state of your program, then pass it to
    :meth:`~lightning_fabric.fabric.Fabric.save` and :meth:`~lightning_fabric.fabric.Fabric.load`.

    Example:
        >>> import torch
        >>> model = torch.nn.Linear(2, 2)
        >>> state = AttributeDict(model=model, iter_num=0)
        >>> state.model
        Linear(in_features=2, out_features=2, bias=True)
        >>> state.iter_num += 1
        >>> state.iter_num
        1
        >>> state
        "iter_num": 1
        "model":    Linear(in_features=2, out_features=2, bias=True)

    """

    def __getattr__(self, key: str) -> Any:
        try:
            return self[key]
        except KeyError as e:
            raise AttributeError(f"'{type(self).__name__}' object has no attribute '{key}'") from e

    def __setattr__(self, key: str, val: Any) -> None:
        self[key] = val

    def __delattr__(self, item: str) -> None:
        if item not in self:
            raise KeyError(item)
        del self[item]

    def __repr__(self) -> str:
        if not len(self):
            return ""
        max_key_length = max(len(str(k)) for k in self)
        tmp_name = "{:" + str(max_key_length + 3) + "s} {}"
        rows = [tmp_name.format(f'"{n}":', self[n]) for n in sorted(self.keys())]
        return "\n".join(rows)
|
|