"""An alternative to DataLoader using ZMQ.

This implements MultiLoader, an alternative to DataLoader when torch
is not available. Subprocesses communicate with the loader through
ZMQ, which provides high-performance multithreaded queueing.
"""
|
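# Minimal usage sketch (hedged; `my_dataset` and `process` are stand-ins,
# not part of this module; any picklable iterable of samples works):
#
#     loader = MultiLoader(my_dataset, workers=4)
#     for sample in loader:
#         process(sample)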
|
|
import multiprocessing as mp
import os
import pickle
import uuid
import weakref

import zmq

# Pickle protocol used for all messages sent between processes.
the_protocol = pickle.HIGHEST_PROTOCOL

# Global registry of all worker processes ever spawned (held weakly).
all_pids = weakref.WeakSet()
|
|
|
|
|
class EOF:
    """Indicate that a data stream is finished.

    This class is used to signal the end of a data stream in the MultiLoader.

    Args:
        **kw: Arbitrary keyword arguments to be set as instance variables.
    """

    def __init__(self, **kw):
        """Initialize the EOF instance with keyword arguments.

        Args:
            **kw: Arbitrary keyword arguments to be set as instance variables.
        """
        self.__dict__.update(kw)
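# `reader` sends EOF(index=worker_index) as its final message, so the
# consumer can tell which worker has finished; for example:
#
#     eof = EOF(index=3)
#     assert eof.index == 3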
|
|
|
|
|
def reader(dataset, sockname, index, num_workers):
    """Read samples from the dataset and send them over the socket.

    This function is run in a separate process to read data from the dataset
    and send it to the main process through a ZMQ socket. The worker index
    and total worker count are exported as the WORKER and NUM_WORKERS
    environment variables so that the dataset can shard itself if desired.

    Args:
        dataset: The source dataset to read samples from.
        sockname (str): The address of the ZMQ socket to send data to.
        index (int): The index of this reader process.
        num_workers (int): The total number of worker processes.

    Returns:
        None
    """
    global the_protocol
    os.environ["WORKER"] = str(index)
    os.environ["NUM_WORKERS"] = str(num_workers)
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.PUSH)
    sock.connect(sockname)
    for sample in dataset:
        data = pickle.dumps(sample, protocol=the_protocol)
        sock.send(data)
    # Signal completion so the consumer knows this worker is done.
    sock.send(pickle.dumps(EOF(index=index)))
    sock.close()
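# A sketch of the wire protocol `reader` speaks, runnable standalone
# (the endpoint name and the `range(10)` dataset are illustrative only):
#
#     ctx = zmq.Context.instance()
#     pull = ctx.socket(zmq.PULL)
#     pull.bind("ipc:///tmp/_multi-demo")
#     p = mp.Process(target=reader, args=(range(10), "ipc:///tmp/_multi-demo", 0, 1))
#     p.start()
#     while True:
#         obj = pickle.loads(pull.recv())
#         if isinstance(obj, EOF):
#             break
#         print(obj)
#     p.join()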
|
|
|
|
|
class MultiLoader:
    """Alternative to PyTorch DataLoader based on ZMQ.

    This class creates a multi-process data loader using ZMQ for inter-process
    communication, providing an alternative to PyTorch's DataLoader.

    Args:
        dataset: The source dataset to load data from.
        workers (int): Number of worker processes to spawn. Defaults to 4.
        verbose (bool): Whether to report progress verbosely. Defaults to False.
        nokill (bool): If True, don't kill old processes when restarting. Defaults to False.
        prefix (str): Directory prefix for the ZMQ socket. Defaults to "/tmp/_multi-".
    """

    def __init__(self, dataset, workers=4, verbose=False, nokill=False, prefix="/tmp/_multi-"):
        """Initialize the MultiLoader instance.

        Args:
            dataset: The source dataset to load data from.
            workers (int): Number of worker processes to spawn. Defaults to 4.
            verbose (bool): Whether to report progress verbosely. Defaults to False.
            nokill (bool): If True, don't kill old processes when restarting. Defaults to False.
            prefix (str): Directory prefix for the ZMQ socket. Defaults to "/tmp/_multi-".
        """
        self.dataset = dataset
        self.workers = workers
        self.verbose = verbose
        self.pids = []
        self.socket = None
        self.ctx = zmq.Context.instance()
        self.nokill = nokill
        self.prefix = prefix
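    # Every worker iterates the full dataset; a dataset that wants true
    # sharding can split work on the WORKER / NUM_WORKERS environment
    # variables that `reader` exports. A hedged sketch (the `Sharded`
    # wrapper is illustrative, not part of this module):
    #
    #     class Sharded:
    #         def __init__(self, source):
    #             self.source = source
    #
    #         def __iter__(self):
    #             worker = int(os.environ.get("WORKER", "0"))
    #             num_workers = int(os.environ.get("NUM_WORKERS", "1"))
    #             for i, sample in enumerate(self.source):
    #                 if i % num_workers == worker:
    #                     yield sample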
|
|
|
    def kill(self):
        """Kill all worker processes and close the ZMQ socket."""
        for pid in self.pids:
            if pid is None:
                continue
            print("killing", pid)
            pid.kill()
            pid.join(1.0)
        self.pids = []
        if self.socket is not None:
            print("closing", self.socket)
            self.socket.close()
            self.socket = None
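    # With nokill=True, __iter__ does not call kill() before starting new
    # workers, so it is up to the caller to invoke kill() once the loader
    # is no longer needed (a hedged usage note, not enforced by the code).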
|
|
|
    def __iter__(self):
        """Return an iterator over this dataloader.

        This method sets up the ZMQ socket, spawns worker processes, and
        yields samples from the dataset until every worker has reported EOF.

        Yields:
            A sample from the dataset.
        """
        if not self.nokill:
            self.kill()
        self.sockname = "ipc://" + self.prefix + str(uuid.uuid4())
        self.socket = self.ctx.socket(zmq.PULL)
        self.socket.bind(self.sockname)
        if self.verbose:
            print("#", self.sockname)
        self.pids = [None] * self.workers
        for index in range(self.workers):
            args = (self.dataset, self.sockname, index, self.workers)
            self.pids[index] = mp.Process(target=reader, args=args)
        all_pids.update(self.pids)
        for pid in self.pids:
            pid.start()
        count = 0
        # Each worker slot is cleared when its EOF arrives; loop until all
        # slots are None.
        while self.pids.count(None) < len(self.pids):
            data = self.socket.recv()
            sample = pickle.loads(data)
            if isinstance(sample, EOF):
                if self.verbose:
                    print("# subprocess finished", sample.index)
                self.pids[sample.index].join(1.0)
                self.pids[sample.index] = None
            else:
                yield sample
                count += 1
|
|
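if __name__ == "__main__":
    # Hedged smoke-test sketch, not part of the original API: range(100)
    # stands in for a real dataset. Each worker iterates the whole dataset,
    # so 4 workers yield 400 samples unless the dataset shards itself via
    # the WORKER / NUM_WORKERS environment variables.
    loader = MultiLoader(range(100), workers=4, verbose=True)
    total = sum(1 for _ in loader)
    print("received", total, "samples")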