|
|
|
import functools |
|
|
|
|
|
def async_execution(fn):
    r"""
    A decorator for a function indicating that the return value of the function
    is guaranteed to be a :class:`~torch.futures.Future` object and this
    function can run asynchronously on the RPC callee. More specifically, the
    callee extracts the :class:`~torch.futures.Future` returned by the wrapped
    function and installs subsequent processing steps as a callback to that
    :class:`~torch.futures.Future`. The installed callback will read the value
    from the :class:`~torch.futures.Future` when completed and send the
    value back as the RPC response. That also means the returned
    :class:`~torch.futures.Future` only exists on the callee side and is never
    sent through RPC. This decorator is useful when the wrapped function's
    (``fn``) execution needs to pause and resume due to, e.g., containing
    :meth:`~torch.distributed.rpc.rpc_async` or waiting for other signals.

    .. note:: To enable asynchronous execution, applications must pass the
        function object returned by this decorator to RPC APIs. If RPC detects
        attributes installed by this decorator, it knows that this function
        returns a ``Future`` object and will handle that accordingly.
        However, this does not mean this decorator has to be the outermost one
        when defining a function. For example, when combined with
        ``@staticmethod`` or ``@classmethod``, ``@rpc.functions.async_execution``
        needs to be the inner decorator to allow the target function to be
        recognized as a static or class function. This target function can
        still execute asynchronously because, when accessed, the static or
        class method preserves attributes installed by
        ``@rpc.functions.async_execution``.

    Args:
        fn (Callable): the function to mark for asynchronous execution. It is
            expected to return a :class:`~torch.futures.Future`.

    Returns:
        A wrapper that forwards all positional and keyword arguments to ``fn``
        unchanged, carries ``fn``'s metadata (as ``functools.wraps`` would),
        and exposes ``fn`` through the ``_wrapped_async_rpc_function``
        attribute that marks it as an async RPC function.

    Example::
        The returned :class:`~torch.futures.Future` object can come from
        :meth:`~torch.distributed.rpc.rpc_async`,
        :meth:`~torch.futures.Future.then`, or the :class:`~torch.futures.Future`
        constructor. The example below shows directly using the
        :class:`~torch.futures.Future` returned by
        :meth:`~torch.futures.Future.then`.

        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @rpc.functions.async_execution
        >>> def async_add_chained(to, x, y, z):
        >>>     # This function runs on "worker1" and returns immediately when
        >>>     # the callback is installed through the `then(cb)` API. In the
        >>>     # meantime, the `rpc_async` to "worker2" can run concurrently.
        >>>     # When the return value of that `rpc_async` arrives at
        >>>     # "worker1", "worker1" will run the lambda function accordingly
        >>>     # and set the value for the previously returned `Future`, which
        >>>     # will then trigger RPC to send the result back to "worker0".
        >>>     return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>         lambda fut: fut.wait() + z
        >>>     )
        >>>
        >>> # On worker0
        >>> # xdoctest: +SKIP
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add_chained,
        >>>     args=("worker2", torch.ones(2), 1, 1)
        >>> )
        >>> print(ret)  # prints tensor([3., 3.])

        When combined with TorchScript decorators, this decorator must be the
        outermost one.

        >>> from torch import Tensor
        >>> from torch.futures import Future
        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> @torch.jit.script
        >>> def script_add(x: Tensor, y: Tensor) -> Tensor:
        >>>     return x + y
        >>>
        >>> @rpc.functions.async_execution
        >>> @torch.jit.script
        >>> def async_add(to: str, x: Tensor, y: Tensor) -> Future[Tensor]:
        >>>     return rpc.rpc_async(to, script_add, (x, y))
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     async_add,
        >>>     args=("worker2", torch.ones(2), 1)
        >>> )
        >>> print(ret)  # prints tensor([2., 2.])

        When combined with a static or class method, this decorator must be
        the inner one.

        >>> from torch.distributed import rpc
        >>>
        >>> # omitting setup and shutdown RPC
        >>>
        >>> # On all workers
        >>> class AsyncExecutionClass:
        >>>
        >>>     @staticmethod
        >>>     @rpc.functions.async_execution
        >>>     def static_async_add(to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>>     @classmethod
        >>>     @rpc.functions.async_execution
        >>>     def class_async_add(cls, to, x, y, z):
        >>>         ret_fut = torch.futures.Future()
        >>>         rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: ret_fut.set_result(fut.wait() + z)
        >>>         )
        >>>         return ret_fut
        >>>
        >>>     @rpc.functions.async_execution
        >>>     def bound_async_add(self, to, x, y, z):
        >>>         return rpc.rpc_async(to, torch.add, args=(x, y)).then(
        >>>             lambda fut: fut.wait() + z
        >>>         )
        >>>
        >>> # On worker0
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.static_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> ret = rpc.rpc_sync(
        >>>     "worker1",
        >>>     AsyncExecutionClass.class_async_add,
        >>>     args=("worker2", torch.ones(2), 1, 2)
        >>> )
        >>> print(ret)  # prints tensor([4., 4.])

        This decorator also works with RRef helpers, i.e.,
        :meth:`torch.distributed.rpc.RRef.rpc_sync`,
        :meth:`torch.distributed.rpc.RRef.rpc_async`, and
        :meth:`torch.distributed.rpc.RRef.remote`.

        >>> from torch.distributed import rpc
        >>>
        >>> # reuse the AsyncExecutionClass class above
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_sync().static_async_add("worker2", torch.ones(2), 1, 2)
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.rpc_async().static_async_add("worker2", torch.ones(2), 1, 2).wait()
        >>> print(ret)  # prints tensor([4., 4.])
        >>>
        >>> rref = rpc.remote("worker1", AsyncExecutionClass)
        >>> ret = rref.remote().static_async_add("worker2", torch.ones(2), 1, 2).to_here()
        >>> print(ret)  # prints tensor([4., 4.])
    """

    def wrapper(*args, **kwargs):
        # Pure pass-through: the wrapper exists only so the marker attribute
        # below can be attached without mutating ``fn`` itself.
        return fn(*args, **kwargs)

    # Same effect as decorating ``wrapper`` with ``@functools.wraps(fn)``:
    # copies ``fn``'s name/doc/module/qualname, merges its ``__dict__`` and
    # sets ``__wrapped__``. This must run before the marker assignment so the
    # marker is not overwritten by a value copied from ``fn.__dict__``.
    functools.update_wrapper(wrapper, fn)

    # Marker attribute that RPC checks to recognize an async-execution
    # function; it also keeps a handle on the original callable.
    wrapper._wrapped_async_rpc_function = fn  # type: ignore[attr-defined]
    return wrapper
|
|