# Copyright The Lightning AI team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from collections.abc import Iterator
from typing import Any, Optional

from typing_extensions import override

from lightning_fabric.utilities.data import sized_len
from pytorch_lightning.utilities.combined_loader import _ITERATOR_RETURN, CombinedLoader
from pytorch_lightning.utilities.exceptions import MisconfigurationException


def _profile_nothing() -> None:
    """Do nothing; default value for the fetchers' start/stop profiler hooks."""
    pass


class _DataFetcher(Iterator):
    """Iterate over a ``CombinedLoader`` while tracking fetching progress.

    The fetcher keeps three pieces of state that callers can read:
    ``fetched`` (number of batches pulled from the underlying iterator since the last reset),
    ``length`` (the value returned by ``sized_len`` for the loader, ``None`` when unknown) and
    ``done`` (whether the iterator is known to be exhausted).

    """

    def __init__(self) -> None:
        self._combined_loader: Optional[CombinedLoader] = None
        self.iterator: Optional[Iterator] = None
        self.fetched: int = 0
        self.done: bool = False
        self.length: Optional[int] = None
        # hooks invoked right before/after each fetch in `__next__`; no-ops unless replaced
        self._start_profiler = _profile_nothing
        self._stop_profiler = _profile_nothing

    @property
    def combined_loader(self) -> CombinedLoader:
        """The loader passed to :meth:`setup`.

        Raises:
            MisconfigurationException: If :meth:`setup` has not been called yet.

        """
        if self._combined_loader is None:
            raise MisconfigurationException(
                f"`{self.__class__.__name__}` should have been `setup` with a `CombinedLoader`."
            )
        return self._combined_loader

    def setup(self, combined_loader: CombinedLoader) -> None:
        """Store the loader that subsequent iteration will read from."""
        self._combined_loader = combined_loader

    @override
    def __iter__(self) -> "_DataFetcher":
        """Create a fresh iterator over the loader and reset the progress state."""
        self.iterator = iter(self.combined_loader)
        self.reset()
        return self

    @override
    def __next__(self) -> _ITERATOR_RETURN:
        """Fetch the next item from the underlying iterator and update ``fetched``/``done``.

        Raises:
            StopIteration: When the underlying iterator is exhausted; ``done`` is set before re-raising.

        """
        assert self.iterator is not None
        self._start_profiler()
        try:
            batch = next(self.iterator)
        except StopIteration:
            self.done = True
            raise
        finally:
            # the stop hook runs whether or not the fetch succeeded
            self._stop_profiler()
        self.fetched += 1
        if self.length is not None:
            # with a known length, exhaustion can be flagged without fetching past the end
            self.done = self.fetched >= self.length
        return batch

    def reset(self) -> None:
        """Zero the fetch counter and, if a loader is set, recompute ``length`` and ``done``."""
        self.fetched = 0
        # teardown calls `reset()`, and if it happens early, `combined_loader` can still be None
        if self._combined_loader is not None:
            self.length = sized_len(self.combined_loader)
            self.done = self.length == 0

    def teardown(self) -> None:
        """Reset the fetcher state, reset the loader (if any) and drop the iterator."""
        self.reset()
        if self._combined_loader is not None:
            self._combined_loader.reset()
        self.iterator = None


class _PrefetchDataFetcher(_DataFetcher):
    """This class is used to control batch fetching flow.

    Args:
        prefetch_batches: Number of batches to pre-fetch. Pre-fetching at least 1 batch is necessary to properly track
            whether a batch is the last one (available with :attr:`self.done`) when the length is not available. The
            value of this argument is ignored when the length is available.

    Raises:
        ValueError: If ``prefetch_batches`` is negative.

    """

    def __init__(self, prefetch_batches: int = 1) -> None:
        super().__init__()
        if prefetch_batches < 0:
            raise ValueError("`prefetch_batches` should at least be 0.")
        self.prefetch_batches = prefetch_batches
        # batches already pulled from the iterator but not yet returned, oldest first
        self.batches: list[Any] = []

    @override
    def __iter__(self) -> "_PrefetchDataFetcher":
        """Start iteration and, when the length is unknown, fill the pre-fetch buffer."""
        super().__iter__()
        if self.length is not None:
            # ignore pre-fetching, it's not necessary
            return self
        # prefetch batches to know when the iterator will be exhausted in advance
        for _ in range(self.prefetch_batches):
            try:
                batch = super().__next__()
                self.batches.append(batch)
            except StopIteration:
                # this would only happen when prefetch_batches > the number of batches available and makes
                # `__next__` jump directly to the empty iterator case without trying to fetch again
                break
        return self

    @override
    def __next__(self) -> _ITERATOR_RETURN:
        """Return the oldest buffered batch (refilling the buffer), or fetch directly when nothing is buffered.

        Raises:
            StopIteration: When there is nothing buffered and the iterator is exhausted.

        """
        if self.batches:
            # there are pre-fetched batches already from a previous `prefetching` call.
            # consume one
            batch = self.batches.pop(0)
            try:
                # refill the consumed batch
                self.batches.append(super().__next__())
            except StopIteration:
                # no more batches to fetch. we are done only if all pre-fetched batches were returned
                self.done = not self.batches
        elif not self.done:
            # this will run only when no pre-fetching was done.
            batch = super().__next__()
        else:
            # the iterator is empty
            raise StopIteration
        return batch

    @override
    def reset(self) -> None:
        """Reset the base state and discard any buffered batches."""
        super().reset()
        self.batches = []


class _DataLoaderIterDataFetcher(_DataFetcher):
    """This class is used to return directly the `dataloader_iter` to the ``LightningModule`` training_step for users
    to implement their own pre-fetching logic. This feature can be activated as follows:

    Example::

        class MyModel(LightningModule):
            def training_step(self, dataloader_iter: Iterator) -> None:
                # it is the user responsibility to fetch and move the batch to the right device.
                batch, batch_idx, dataloader_idx = next(dataloader_iter)
                batch = batch.to(self.device)
                ...

    """

    def __init__(self, *args: Any, **kwargs: Any) -> None:
        super().__init__(*args, **kwargs)
        # last values produced through the wrapper, stored so the loops can access them
        self._batch: Any = None
        self._batch_idx: int = 0
        self._dataloader_idx: int = 0
        # built in `__iter__`; declared here so the attribute exists before iteration starts
        self.iterator_wrapper: Optional[Iterator] = None

    @override
    def __iter__(self) -> "_DataLoaderIterDataFetcher":
        """Start iteration and build the wrapper object that :meth:`__next__` hands out."""
        super().__iter__()
        self.iterator_wrapper = iter(_DataFetcherWrapper(self))
        return self

    @override
    def __next__(self) -> Iterator["_DataFetcherWrapper"]:  # type: ignore[override]
        """Return the wrapper iterator itself rather than a batch.

        Raises:
            StopIteration: When ``done`` is set.

        """
        if self.done:
            raise StopIteration
        # mirrors the `self.iterator` check in the base class: `__iter__` must have run first
        assert self.iterator_wrapper is not None
        return self.iterator_wrapper

    @override
    def reset(self) -> None:
        """Reset the base state and the stored batch, batch index and dataloader index."""
        super().reset()
        self._batch = None
        self._batch_idx = 0
        self._dataloader_idx = 0


class _DataFetcherWrapper(Iterator):
    """Iterator handed out by ``_DataLoaderIterDataFetcher`` that fetches through the owning fetcher.

    ``done``, ``fetched`` and ``length`` are read-only views of the owning fetcher's attributes.

    """

    def __init__(self, data_fetcher: _DataLoaderIterDataFetcher) -> None:
        self.data_fetcher = data_fetcher

    @property
    def done(self) -> bool:
        return self.data_fetcher.done

    @property
    def fetched(self) -> int:
        return self.data_fetcher.fetched

    @property
    def length(self) -> Optional[int]:
        return self.data_fetcher.length

    @override
    def __next__(self) -> _ITERATOR_RETURN:
        """Fetch the next ``(batch, batch_idx, dataloader_idx)`` and record it on the owning fetcher.

        Raises:
            StopIteration: When the owning fetcher is ``done`` or its iterator is exhausted.

        """
        fetcher = self.data_fetcher
        if fetcher.done:
            raise StopIteration
        # skip `_DataLoaderIterDataFetcher.__next__` (which returns this wrapper) and use the parent's fetch logic
        batch, batch_idx, dataloader_idx = super(_DataLoaderIterDataFetcher, fetcher).__next__()
        # save the state so the loops can access it
        fetcher._batch = batch
        fetcher._batch_idx = batch_idx
        fetcher._dataloader_idx = dataloader_idx
        return batch, batch_idx, dataloader_idx