File size: 5,185 Bytes
9c6594c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
#
# Copyright (c) 2017-2021 NVIDIA CORPORATION. All rights reserved.
# This file is part of the WebDataset library.
# See the LICENSE file for licensing terms (BSD-style).
#

"""An alternative to DataLoader using ZMQ.

This implements MultiLoader, an alternative to DataLoader when torch
is not available. Subprocesses communicate with the loader through
ZMQ, provided for high performance multithreaded queueing.
"""

import multiprocessing as mp
import os
import pickle
import uuid
import weakref

import zmq

# Pickle protocol used by reader() when serializing samples for the socket.
the_protocol = pickle.HIGHEST_PROTOCOL

# Weak references to the worker Process objects created by MultiLoader.__iter__;
# being a WeakSet, it does not keep those objects alive by itself.
all_pids = weakref.WeakSet()


class EOF:
    """Marker object signalling that a data stream has ended.

    Instances carry whatever attributes were supplied as keyword arguments
    at construction time.

    Args:
        **kw: Arbitrary keyword arguments; each becomes an instance attribute.
    """

    def __init__(self, **kw):
        """Store every keyword argument as an attribute of this instance.

        Args:
            **kw: Arbitrary keyword arguments; each becomes an instance attribute.
        """
        attributes = vars(self)
        attributes.update(kw)


def reader(dataset, sockname, index, num_workers):
    """Read samples from the dataset and send them over the socket.

    This function is the target of the worker processes: it connects a ZMQ
    PUSH socket to ``sockname``, pickles every sample yielded by ``dataset``
    and sends it, then sends a pickled ``EOF(index=index)`` marker.

    The socket is closed even if iterating the dataset raises; in that case
    no EOF marker is sent and the exception propagates to the caller.

    Args:
        dataset: The source dataset to read samples from (any iterable).
        sockname (str): The name of the ZMQ socket to send data to.
        index (int): The index of this reader process.
        num_workers (int): The total number of worker processes.

    Returns:
        None
    """
    # Publish this worker's identity in the environment before the dataset
    # is iterated (presumably so dataset code can split work per worker).
    os.environ["WORKER"] = str(index)
    os.environ["NUM_WORKERS"] = str(num_workers)
    ctx = zmq.Context.instance()
    sock = ctx.socket(zmq.PUSH)
    try:
        sock.connect(sockname)
        for sample in dataset:
            sock.send(pickle.dumps(sample, protocol=the_protocol))
        # Tell the receiver that this worker has no more samples; use the
        # same pickle protocol as for the samples themselves.
        sock.send(pickle.dumps(EOF(index=index), protocol=the_protocol))
    finally:
        # Release the socket even when the dataset fails part-way through.
        sock.close()


class MultiLoader:
    """Alternative to PyTorch DataLoader based on ZMQ.

    This class creates a multi-process data loader using ZMQ for inter-process
    communication, providing an alternative to PyTorch's DataLoader. Each
    iteration binds a fresh PULL socket, spawns ``workers`` processes running
    ``reader`` and yields the samples they send.

    Args:
        dataset: The source dataset to load data from.
        workers (int): Number of worker processes to spawn. Defaults to 4.
        verbose (bool): Whether to report progress verbosely. Defaults to False.
        nokill (bool): If True, don't kill old processes when restarting. Defaults to False.
        prefix (str): Directory prefix for the ZMQ socket. Defaults to "/tmp/_multi-".
    """

    def __init__(self, dataset, workers=4, verbose=False, nokill=False, prefix="/tmp/_multi-"):
        """Initialize the MultiLoader instance.

        No processes or sockets are created here; that happens in ``__iter__``.

        Args:
            dataset: The source dataset to load data from.
            workers (int): Number of worker processes to spawn. Defaults to 4.
            verbose (bool): Whether to report progress verbosely. Defaults to False.
            nokill (bool): If True, don't kill old processes when restarting. Defaults to False.
            prefix (str): Directory prefix for the ZMQ socket. Defaults to "/tmp/_multi-".
        """
        self.dataset = dataset
        self.workers = workers
        self.verbose = verbose
        # One entry per worker; an entry is set to None once that worker finished.
        self.pids = []
        self.socket = None
        self.ctx = zmq.Context.instance()
        self.nokill = nokill
        self.prefix = prefix

    def kill(self):
        """Kill all worker processes and close the ZMQ socket.

        Every worker still tracked in ``self.pids`` is killed and joined with a
        one second timeout, then the socket (if any) is closed. Progress
        messages are printed only when ``verbose`` is set, matching the
        reporting done in ``__iter__``.
        """
        for pid in self.pids:
            if pid is None:
                # This worker already reported EOF and was joined.
                continue
            if self.verbose:
                print("killing", pid)
            pid.kill()
            pid.join(1.0)
        self.pids = []
        if self.socket is not None:
            if self.verbose:
                print("closing", self.socket)
            self.socket.close()
        self.socket = None

    def __iter__(self):
        """Return an iterator over this dataloader.

        This method sets up the ZMQ socket, spawns worker processes, and yields
        samples from the dataset until every worker has sent its EOF marker.

        Unless ``nokill`` is set, workers and socket of a previous iteration
        are torn down first, and the ones created here are torn down again
        when the iteration finishes, fails, or is abandoned early.

        Yields:
            Sample: A sample from the dataset.
        """
        if not self.nokill:
            self.kill()
        self.sockname = "ipc://" + self.prefix + str(uuid.uuid4())
        self.socket = self.ctx.socket(zmq.PULL)
        self.socket.bind(self.sockname)
        if self.verbose:
            print("#", self.sockname)
        procs = [
            mp.Process(target=reader, args=(self.dataset, self.sockname, index, self.workers))
            for index in range(self.workers)
        ]
        self.pids = procs
        all_pids.update(procs)
        for proc in procs:
            proc.start()
        try:
            # Keep receiving until every slot has been cleared by an EOF marker.
            while any(pid is not None for pid in self.pids):
                data = self.socket.recv()
                # NOTE(review): this unpickles whatever arrives on the ipc socket;
                # that is only safe while nothing but our own workers can connect
                # to the socket path -- confirm for shared prefixes such as /tmp.
                sample = pickle.loads(data)
                if isinstance(sample, EOF):
                    if self.verbose:
                        print("# subprocess finished", sample.index)
                    self.pids[sample.index].join(1.0)
                    self.pids[sample.index] = None
                else:
                    yield sample
        finally:
            # Only clean up state that still belongs to this iteration: kill()
            # rebinds self.pids, so a mismatch means someone else already
            # replaced or released it.
            if self.pids is procs and not self.nokill:
                self.kill()