from torch.utils.data import DataLoader, RandomSampler
from torch.utils.data.distributed import DistributedSampler

from deepspeed.accelerator import get_accelerator

from deepspeed.runtime.data_pipeline.data_sampling.data_sampler import DeepSpeedDataSampler
from deepspeed.runtime.data_pipeline.constants import CURRICULUM_LEARNING, \
    DATA_EFFICIENCY, DATA_SAMPLING_NUM_WORKERS
from deepspeed.runtime.constants import GRADIENT_ACCUMULATION_STEPS, \
    DATA_PARALLEL_GROUP, GLOBAL_RANK


class RepeatingLoader:

    def __init__(self, loader):
        """Wraps an iterator to allow for infinite iteration. This is especially useful
        for DataLoader types that we wish to automatically restart upon completion.

        Args:
            loader (iterator): The data loader to repeat.
        """
        self.loader = loader
        self.data_iter = iter(self.loader)

    def __iter__(self):
        return self

    def __next__(self):
        try:
            batch = next(self.data_iter)
        except StopIteration:
            self.data_iter = iter(self.loader)
            batch = next(self.data_iter)
        return batch

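# A minimal usage sketch of RepeatingLoader: wrapping a finite DataLoader so a
# training loop can call next() indefinitely without handling StopIteration.
# The dataset and batch size below are arbitrary placeholders.
def _repeating_loader_example():
    import torch
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.arange(10, dtype=torch.float32))
    loader = RepeatingLoader(DataLoader(dataset, batch_size=4))
    # Draw more batches than one pass provides; the loader restarts on exhaustion.
    return [next(loader) for _ in range(5)]

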
class DeepSpeedDataLoader(object):

    def __init__(self,
                 dataset,
                 batch_size,
                 pin_memory,
                 local_rank,
                 tput_timer,
                 collate_fn=None,
                 num_local_io_workers=None,
                 data_sampler=None,
                 data_parallel_world_size=None,
                 data_parallel_rank=None,
                 dataloader_drop_last=False,
                 deepspeed_dataloader_config={}):
        self.deepspeed_dataloader_config = deepspeed_dataloader_config
        self.tput_timer = tput_timer
        self.batch_size = batch_size

        self.curriculum_learning_enabled = False
        if CURRICULUM_LEARNING in deepspeed_dataloader_config:
            self.curriculum_learning_enabled = deepspeed_dataloader_config[CURRICULUM_LEARNING]

        if self.curriculum_learning_enabled:
            # Curriculum learning: build the sampler from the data efficiency config;
            # it yields batches of indices and is used as a batch sampler in
            # _create_dataloader.
            data_sampler = DeepSpeedDataSampler(self.deepspeed_dataloader_config[DATA_EFFICIENCY],
                                                len(dataset),
                                                self.batch_size,
                                                data_parallel_rank,
                                                data_parallel_world_size,
                                                self.deepspeed_dataloader_config[DATA_PARALLEL_GROUP],
                                                self.deepspeed_dataloader_config[GRADIENT_ACCUMULATION_STEPS],
                                                self.deepspeed_dataloader_config[GLOBAL_RANK],
                                                drop_last=dataloader_drop_last)
            device_count = get_accelerator().device_count()
            num_local_io_workers = self.deepspeed_dataloader_config[DATA_SAMPLING_NUM_WORKERS]
        else:
            if local_rank >= 0:
                # Distributed training: each data-parallel rank samples its own shard.
                if data_sampler is None:
                    data_sampler = DistributedSampler(dataset=dataset,
                                                      num_replicas=data_parallel_world_size,
                                                      rank=data_parallel_rank)
                device_count = 1
            else:
                # Single-process training: scale the batch size by the local device count.
                if data_sampler is None:
                    data_sampler = RandomSampler(dataset)
                device_count = get_accelerator().device_count()
                batch_size *= device_count

            if num_local_io_workers is None:
                num_local_io_workers = 2 * device_count

        self.num_local_io_workers = num_local_io_workers
        self.data_sampler = data_sampler
        self.dataset = dataset
        self.collate_fn = collate_fn
        self.device_count = device_count
        self.batch_size = batch_size
        self.pin_memory = pin_memory
        self.data = None
        self.dataloader_drop_last = dataloader_drop_last
        self.post_process_func = None

        if self.dataloader_drop_last:
            self.len = len(self.data_sampler) // self.batch_size
        else:
            from math import ceil
            self.len = ceil(len(self.data_sampler) / self.batch_size)

    def __iter__(self):
        self._create_dataloader()
        return self

    def __len__(self):
        return self.len

    def __next__(self):
        if self.tput_timer:
            self.tput_timer.start()
        if self.curriculum_learning_enabled:
            data = next(self.data_iterator)
            if self.post_process_func is not None:
                data = self.post_process_func(data, self.data_sampler.state_dict())
            return data
        else:
            return next(self.data)

    def _create_dataloader(self):
        if self.curriculum_learning_enabled:
            if self.collate_fn is None:
                self.dataloader = DataLoader(self.dataset,
                                             pin_memory=self.pin_memory,
                                             batch_sampler=self.data_sampler,
                                             num_workers=self.num_local_io_workers)
            else:
                self.dataloader = DataLoader(self.dataset,
                                             pin_memory=self.pin_memory,
                                             batch_sampler=self.data_sampler,
                                             collate_fn=self.collate_fn,
                                             num_workers=self.num_local_io_workers)
            self.data_iterator = iter(self.dataloader)
            return self.dataloader
        else:
            if self.collate_fn is None:
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             pin_memory=self.pin_memory,
                                             sampler=self.data_sampler,
                                             num_workers=self.num_local_io_workers,
                                             drop_last=self.dataloader_drop_last)
            else:
                self.dataloader = DataLoader(self.dataset,
                                             batch_size=self.batch_size,
                                             pin_memory=self.pin_memory,
                                             sampler=self.data_sampler,
                                             collate_fn=self.collate_fn,
                                             num_workers=self.num_local_io_workers,
                                             drop_last=self.dataloader_drop_last)
            self.data = (x for x in self.dataloader)

            return self.dataloader

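
# A minimal construction sketch for DeepSpeedDataLoader on the non-curriculum,
# single-process path (local_rank < 0): a RandomSampler is created and the batch
# size is scaled by the local device count. The dataset, batch size and
# pin_memory flag below are arbitrary placeholders; tput_timer is optional and
# left as None here.
def _deepspeed_dataloader_example():
    import torch
    from torch.utils.data import TensorDataset

    dataset = TensorDataset(torch.randn(32, 4), torch.arange(32))
    dataloader = DeepSpeedDataLoader(dataset,
                                     batch_size=8,
                                     pin_memory=False,
                                     local_rank=-1,
                                     tput_timer=None)
    # One full pass over the dataset; iteration stops when the underlying
    # DataLoader is exhausted.
    return [batch for batch in dataloader]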