import contextlib
from multiprocessing import Pool, RLock

from tqdm.auto import tqdm

from ..utils import experimental, logging


logger = logging.get_logger(__name__)


class ParallelBackendConfig:
    backend_name = None


@experimental
def parallel_map(function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func):
""" |
|
**Experimental.** Apply a function to iterable elements in parallel, where the implementation uses either |
|
multiprocessing.Pool or joblib for parallelization. |
|
|
|
Args: |
|
function (`Callable[[Any], Any]`): Function to be applied to `iterable`. |
|
iterable (`list`, `tuple` or `np.ndarray`): Iterable elements to apply function to. |
|
num_proc (`int`): Number of processes (if no backend specified) or jobs (using joblib). |
|
types (`tuple`): Additional types (besides `dict` values) to apply `function` recursively to their elements. |
|
disable_tqdm (`bool`): Whether to disable the tqdm progressbar. |
|
desc (`str`): Prefix for the tqdm progressbar. |
|
single_map_nested_func (`Callable`): Map function that applies `function` to an element from `iterable`. |
|
Takes a tuple of function, data_struct, types, rank, disable_tqdm, desc as input, where data_struct is an |
|
element of `iterable`, and `rank` is used for progress bar. |
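
    Example usage (an illustrative sketch: `_toy_single_map` is a hypothetical stand-in for the
    `single_map_nested_func` normally supplied by the caller, and both functions must be picklable,
    i.e. defined at module level, for the multiprocessing path):

    ```py
    def _toy_single_map(args):
        # hypothetical stand-in: apply `function` to every element of the assigned slice
        function, data_struct, batched, batch_size, types, rank, disable_tqdm, desc = args
        return [function(x) for x in data_struct]

    parallel_map(
        str, [1, 2, 3, 4], num_proc=2, batched=False, batch_size=None, types=(list,),
        disable_tqdm=True, desc=None, single_map_nested_func=_toy_single_map,
    )  # -> ["1", "2", "3", "4"]
    ```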
""" |
|
if ParallelBackendConfig.backend_name is None: |
|
return _map_with_multiprocessing_pool( |
|
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func |
|
) |
|
|
|
return _map_with_joblib( |
|
function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func |
|
) |
|
|
|
|
|
def _map_with_multiprocessing_pool(
    function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
):
    num_proc = num_proc if num_proc <= len(iterable) else len(iterable)
    split_kwds = []
    for index in range(num_proc):
        div = len(iterable) // num_proc
        mod = len(iterable) % num_proc
        start = div * index + min(index, mod)
        end = start + div + (1 if index < mod else 0)
        split_kwds.append((function, iterable[start:end], batched, batch_size, types, index, disable_tqdm, desc))
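
    # Illustrative note: each process gets a contiguous slice, and the first `len(iterable) % num_proc`
    # slices get one extra element. For example, 10 elements across 3 processes yield slices of
    # lengths [4, 3, 3], covering indices [0:4], [4:7] and [7:10].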
    if len(iterable) != sum(len(i[1]) for i in split_kwds):
        raise ValueError(
            f"Error dividing the input iterable among processes. "
            f"Total number of objects: {len(iterable)}, "
            f"total length of slices: {sum(len(i[1]) for i in split_kwds)}."
        )

    logger.info(
        f"Spawning {num_proc} processes for {len(iterable)} objects in slices of {[len(i[1]) for i in split_kwds]}"
    )
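
    # Note: when progress bars are enabled, the Pool initializer below calls `tqdm.set_lock` with a shared
    # multiprocessing `RLock` in every worker process, so the per-process bars do not garble each other's output.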
    initargs, initializer = None, None
    if not disable_tqdm:
        initargs, initializer = (RLock(),), tqdm.set_lock
    with Pool(num_proc, initargs=initargs, initializer=initializer) as pool:
        mapped = pool.map(single_map_nested_func, split_kwds)
    logger.info(f"Finished {num_proc} processes")
    mapped = [obj for proc_res in mapped for obj in proc_res]
    logger.info(f"Unpacked {len(mapped)} objects")

    return mapped


def _map_with_joblib(
    function, iterable, num_proc, batched, batch_size, types, disable_tqdm, desc, single_map_nested_func
):
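    # Note: progress bars are not wired through the joblib path; `single_map_nested_func` is invoked with
    # disable_tqdm=True below, since tqdm output cannot easily be coordinated across joblib workers.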
    import joblib

    with joblib.parallel_backend(ParallelBackendConfig.backend_name, n_jobs=num_proc):
        return joblib.Parallel()(
            joblib.delayed(single_map_nested_func)((function, obj, batched, batch_size, types, None, True, None))
            for obj in iterable
        )


@experimental
@contextlib.contextmanager
def parallel_backend(backend_name: str):
    """
    **Experimental.** Configures the parallel backend for parallelized dataset loading, which uses the parallelization
    implemented by joblib.

    Args:
        backend_name (`str`): Name of the backend for the parallelization implementation; it has to be supported by joblib.

    Example usage:
    ```py
    with parallel_backend('spark'):
        dataset = load_dataset(..., num_proc=2)
    ```
    """
    ParallelBackendConfig.backend_name = backend_name

    if backend_name == "spark":
        from joblibspark import register_spark

        register_spark()
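
        # Note: `register_spark()` comes from the optional `joblibspark` package and registers the
        # "spark" backend with joblib, so joblib.Parallel can distribute jobs over a Spark cluster.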

    try:
        yield
    finally:
        ParallelBackendConfig.backend_name = None