
Commit 2393bab

lw authored and facebook-github-bot committed
[TensorPipe] Update documentation (pytorch#40222)
Summary:
Pull Request resolved: pytorch#40222

Mention the TensorPipe agent in the RPC docs and give users the information they need to choose which agent to use.

ghstack-source-id: 106225711

Test Plan: Export to GitHub, build locally and try out the docs.

Differential Revision: D22116494

fbshipit-source-id: 30703ba8410c40f64e785f60d71dfd9faa8de4a1
1 parent 8315bb2 commit 2393bab

4 files changed: 175 additions, 43 deletions


docs/source/rpc.rst

Lines changed: 118 additions & 8 deletions
@@ -82,10 +82,7 @@ RPC
 Before using RPC and distributed autograd primitives, initialization must take
 place. To initialize the RPC framework we need to use
 :meth:`~torch.distributed.rpc.init_rpc` which would initialize the RPC
-framework, RRef framework and distributed autograd. By default, this will also
-initialize the ``ProcessGroup`` (:meth:`~torch.distributed.init_process_group`)
-backend for RPC communication. The ``ProcessGroup`` backend internally uses gloo
-for communication.
+framework, RRef framework and distributed autograd.
 
 .. automodule:: torch.distributed.rpc
 .. autofunction:: init_rpc
@@ -109,9 +106,6 @@ and move it to the desired devices on the callee if necessary.
 .. autofunction:: shutdown
 .. autoclass:: WorkerInfo
     :members:
-.. autoclass:: ProcessGroupRpcBackendOptions
-    :members:
-    :inherited-members:
 
 
 The RPC package also provides decorators which allow applications to specify
@@ -122,8 +116,124 @@ how a given function should be treated on the callee side.
 
 .. autofunction:: torch.distributed.rpc.functions.async_execution
 
-.. _rref:
 
+.. _rpc-backends:
+
+Backends
+^^^^^^^^
+
+The RPC module can leverage different backends to perform the communication
+between the nodes. The backend to be used can be specified in the
+:func:`~torch.distributed.rpc.init_rpc` function, by passing a certain value of
+the :class:`~torch.distributed.rpc.BackendType` enum. Regardless of what backend
+is used, the rest of the RPC API won't change. Each backend also defines its own
+subclass of the :class:`~torch.distributed.rpc.RpcBackendOptions` class, an
+instance of which can also be passed to :func:`~torch.distributed.rpc.init_rpc`
+to configure the backend's behavior.
+
+.. autoclass:: BackendType
+
+.. autoclass:: RpcBackendOptions
+    :members:
+
+
+Process Group Backend
+"""""""""""""""""""""
+
+The Process Group agent, which is the default, instantiates a process group from
+the :mod:`~torch.distributed` module and utilizes its point-to-point
+communication capabilities to send RPC messages across. Internally, the process
+group uses `the Gloo library <https://github.com/facebookincubator/gloo/>`_.
+
+Gloo has been hardened by years of extensive use in PyTorch and is thus very
+reliable. However, as it was designed to perform collective communication, it
+may not always be the best fit for RPC. For example, each networking operation
+is synchronous and blocking, which means that it cannot be run in parallel with
+others. Moreover, it opens a connection between all pairs of nodes, and brings
+down all of them when one fails, thus reducing the resiliency and the elasticity
+of the system.
+
+Example::
+
+    >>> import os
+    >>> from torch.distributed import rpc
+    >>> os.environ['MASTER_ADDR'] = 'localhost'
+    >>> os.environ['MASTER_PORT'] = '29500'
+    >>>
+    >>> rpc.init_rpc(
+    >>>     "worker1",
+    >>>     rank=0,
+    >>>     world_size=2,
+    >>>     rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
+    >>>         num_send_recv_threads=16,
+    >>>         rpc_timeout=20 # 20 second timeout
+    >>>     )
+    >>> )
+    >>>
+    >>> # omitting init_rpc invocation on worker2
+
+
+.. autoclass:: ProcessGroupRpcBackendOptions
+    :members:
+    :inherited-members:
+
+
+TensorPipe Backend
+""""""""""""""""""
+
+.. warning::
+    The TensorPipe backend is a **beta feature**.
+
+The TensorPipe agent leverages `the TensorPipe library
+<https://github.com/pytorch/tensorpipe>`_, which provides a natively
+point-to-point communication primitive specifically suited for machine learning
+that fundamentally addresses some of the limitations of Gloo. Compared to Gloo,
+it has the advantage of being asynchronous, which allows a large number of
+transfers to occur simultaneously, each at their own speed, without blocking
+each other. It will only open pipes between pairs of nodes when needed, on
+demand, and when one node fails only its incident pipes will be closed, while
+all other ones will keep working as normal. In addition, it is able to support
+multiple different transports (TCP, of course, but also shared memory, NVLink,
+InfiniBand, ...) and can automatically detect their availability and negotiate
+the best transport to use for each pipe.
+
+The TensorPipe backend has been introduced in PyTorch v1.6 and is being actively
+developed. At the moment, it only supports CPU tensors, with GPU support coming
+soon. It comes with a TCP-based transport, just like Gloo. It is also able to
+automatically chunk and multiplex large tensors over multiple sockets and
+threads in order to achieve very high bandwidths. In addition to that, it packs
+two Linux-specific transports for communication between processes on a same
+machine (one based on ringbuffers stored in shared memory, the other on the
+cross-memory attach syscalls) which can achieve lower latencies than TCP.
+The agent will be able to pick the best transport on its own, with no
+intervention required.
+
+Example::
+
+    >>> import os
+    >>> from torch.distributed import rpc
+    >>> os.environ['MASTER_ADDR'] = 'localhost'
+    >>> os.environ['MASTER_PORT'] = '29500'
+    >>>
+    >>> rpc.init_rpc(
+    >>>     "worker1",
+    >>>     rank=0,
+    >>>     world_size=2,
+    >>>     backend=rpc.BackendType.TENSORPIPE,
+    >>>     rpc_backend_options=rpc.TensorPipeRpcBackendOptions(
+    >>>         num_worker_threads=8,
+    >>>         rpc_timeout=20 # 20 second timeout
+    >>>     )
+    >>> )
+    >>>
+    >>> # omitting init_rpc invocation on worker2
+
+.. autoclass:: TensorPipeRpcBackendOptions
+    :members:
+    :inherited-members:
+
+
+.. _rref:
 
 RRef
 ----
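The new Backends section stresses that the RPC call sites do not change with the backend. As a rough illustration of that point (not part of this patch; worker names and tensor values are made up), once init_rpc has completed on two workers with either agent, remote calls look the same:

    # Sketch only: the RPC call sites are backend-agnostic.
    import torch
    from torch.distributed import rpc

    # Synchronous call: runs torch.add on "worker2" and returns the result.
    ret = rpc.rpc_sync("worker2", torch.add, args=(torch.ones(2), 3))

    # Asynchronous variant returning a future.
    fut = rpc.rpc_async("worker2", torch.add, args=(torch.ones(2), 3))
    ret_async = fut.wait()

    # Shut down the RPC framework when done.
    rpc.shutdown()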

torch/csrc/distributed/rpc/init.cpp

Lines changed: 31 additions & 21 deletions
@@ -397,25 +397,6 @@ PyObject* rpc_init(PyObject* /* unused */) {
                 :meth:`~torch.distributed.rpc.rpc_async` if necessary.
             init_method (str, optional): The URL to initialize
                 ``ProcessGroupGloo`` (default: ``env://``).
-
-
-        Example::
-            >>> import datetime, os
-            >>> from torch.distributed import rpc
-            >>> os.environ['MASTER_ADDR'] = 'localhost'
-            >>> os.environ['MASTER_PORT'] = '29500'
-            >>>
-            >>> rpc.init_rpc(
-            >>>     "worker1",
-            >>>     rank=0,
-            >>>     world_size=2,
-            >>>     rpc_backend_options=rpc.ProcessGroupRpcBackendOptions(
-            >>>         num_send_recv_threads=16,
-            >>>         rpc_timeout=20 # 20 second timeout
-            >>>     )
-            >>> )
-            >>>
-            >>> # omitting init_rpc invocation on worker2
     )")
         .def(
             py::init<int, float, std::string>(),
@@ -473,7 +454,30 @@ PyObject* rpc_init(PyObject* /* unused */) {
 
   // Base class: torch.distributed.rpc.RpcBackendOptions.
   py::class_<TensorPipeRpcBackendOptions>(
-      module, "TensorPipeRpcBackendOptions", rpcBackendOptions)
+      module,
+      "TensorPipeRpcBackendOptions",
+      rpcBackendOptions,
+      R"(
+          The backend options for
+          :class:`~torch.distributed.rpc.TensorPipeAgent`, derived from
+          :class:`~torch.distributed.rpc.RpcBackendOptions`.
+
+          Arguments:
+              num_worker_threads (int, optional): The number of threads in the
+                  thread-pool used by
+                  :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
+                  requests (default: 16).
+              rpc_timeout (float, optional): The default timeout, in seconds,
+                  for RPC requests (default: 60 seconds). If the RPC has not
+                  completed in this timeframe, an exception indicating so will
+                  be raised. Callers can override this timeout for individual
+                  RPCs in :meth:`~torch.distributed.rpc.rpc_sync` and
+                  :meth:`~torch.distributed.rpc.rpc_async` if necessary.
+              init_method (str, optional): The URL to initialize the distributed
+                  store used for rendezvous. It takes any value accepted for the
+                  same argument of :meth:`~torch.distributed.init_process_group`
+                  (default: ``env://``).
+      )")
       .def(
           py::init<
              int,
@@ -487,7 +491,13 @@ PyObject* rpc_init(PyObject* /* unused */) {
           py::arg("rpc_timeout") = kDefaultRpcTimeoutSeconds,
           py::arg("init_method") = kDefaultInitMethod)
       .def_readwrite(
-          "num_worker_threads", &TensorPipeRpcBackendOptions::numWorkerThreads);
+          "num_worker_threads",
+          &TensorPipeRpcBackendOptions::numWorkerThreads,
+          R"(
+              The number of threads in the thread-pool used by
+              :class:`~torch.distributed.rpc.TensorPipeAgent` to execute
+              requests.
+          )");
 
   module.attr("_DEFAULT_NUM_WORKER_THREADS") =
       py::cast(kDefaultNumWorkerThreads);
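The docstring added above notes that the default rpc_timeout can be overridden for individual RPCs. A minimal sketch of the documented options and of the per-call ``timeout`` argument of rpc_sync/rpc_async (the sketch is not part of this patch and all values are illustrative):

    import torch
    from torch.distributed import rpc

    # Illustrative options mirroring the documented defaults.
    opts = rpc.TensorPipeRpcBackendOptions(
        num_worker_threads=16,   # thread-pool size used by the agent (default: 16)
        rpc_timeout=60,          # default per-RPC timeout in seconds (default: 60)
        init_method="env://",    # rendezvous URL, same format as init_process_group
    )

    rpc.init_rpc(
        "worker1",
        backend=rpc.BackendType.TENSORPIPE,
        rank=0,
        world_size=2,
        rpc_backend_options=opts,
    )

    # Override the default timeout for a single RPC (5 seconds here).
    ret = rpc.rpc_sync("worker2", torch.add, args=(torch.ones(2), 3), timeout=5)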

torch/distributed/rpc/__init__.py

Lines changed: 16 additions & 14 deletions
@@ -18,14 +18,15 @@ def is_available():
 if is_available():
     from . import api, backend_registry, functions, _set_profiler_node_id
     from .api import *  # noqa: F401
+    from .backend_registry import BackendType
     from .server_process_global_profiler import (
         _server_process_global_profile,
     )
     import torch.distributed.autograd as dist_autograd
 
     def init_rpc(
         name,
-        backend=backend_registry.BackendType.PROCESS_GROUP,
+        backend=BackendType.PROCESS_GROUP,
         rank=-1,
         world_size=None,
         rpc_backend_options=None,
@@ -38,27 +39,28 @@ def init_rpc(
         process ready to send and receive RPCs.
 
         Arguments:
-            backend (Enum): type of RPC backend implementation. Currently,
-                process group backend is the only available backend
-                implementation. (default: ``RpcBackend.PROCESS_GROUP``).
+            backend (BackendType, optional): The type of RPC backend
+                implementation. Supported values include
+                ``BackendType.PROCESS_GROUP`` (the default) and
+                ``BackendType.TENSORPIPE``. See :ref:`rpc-backends` for more
+                information.
             name (str): a globally unique name of this node. (e.g.,
                 ``Trainer3``, ``ParameterServer2``, ``Master``, ``Worker1``)
                 Name can only contain number, alphabet, underscore, and/or dash,
                 and must be shorter than 128 characters.
             rank (int): a globally unique id/rank of this node.
             world_size (int): The number of workers in the group.
-            rpc_backend_options (RpcBackendOptions): The options passed to
-                RpcAgent constructor. It contains RpcAgent specific
-                initialization configurations. By default, it contains
-                ``rpc_timeout = timedelta(seconds=60)``,
-                ``init_method = "env://"``, ``num_send_recv_threads = 4`` for
-                process group agent. If using the default
-                ``rpc_backend_options``, RPC would initialize the underlying
-                process group backend using ``init_method = "env://"``,
+            rpc_backend_options (RpcBackendOptions, optional): The options
+                passed to the RpcAgent constructor. It must be an agent-specific
+                subclass of :class:`~torch.distributed.rpc.RpcBackendOptions`
+                and contains agent-specific initialization configurations. By
+                default, for all agents, it sets the default timeout to 60
+                seconds and performs the rendezvous with an underlying process
+                group initialized using ``init_method = "env://"``,
                 meaning that environment variables ``MASTER_ADDR`` and
                 ``MASTER_PORT`` needs to be set properly. See
-                :class:`~torch.distributed.rpc.ProcessGroupRpcBackendOptions`
-                for examples.
+                :ref:`rpc-backends` for more information and find which options
+                are available.
         """
 
         if not rpc_backend_options:
torch/distributed/rpc/backend_registry.py

Lines changed: 10 additions & 0 deletions
@@ -18,9 +18,18 @@ def _backend_type_repr(self):
     return "BackendType." + self.name
 
 
+_backend_type_doc = """
+    An enum class of available backends.
+
+    PyTorch ships with two builtin backends: ``BackendType.PROCESS_GROUP`` and
+    ``BackendType.TENSORPIPE``. Additional ones can be registered using the
+    :func:`~torch.distributed.rpc.backend_registry.register_backend` function.
+"""
+
 # Create an enum type, `BackendType`, with empty members.
 BackendType = enum.Enum(value="BackendType", names={})
 BackendType.__repr__ = _backend_type_repr
+BackendType.__doc__ = _backend_type_doc
 
 def backend_registered(backend_name):
     """
@@ -65,6 +74,7 @@ def register_backend(
     )
     BackendType = enum.Enum(value="BackendType", names=extended_enum_dict)
     BackendType.__repr__ = _backend_type_repr
+    BackendType.__doc__ = _backend_type_doc
     return BackendType[backend_name]
 
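Since BackendType is created dynamically with enum.Enum, the patch attaches the docstring and repr by hand in both places where the enum is (re)built. A small sketch of the resulting behaviour (not part of the commit):

    from torch.distributed import rpc

    repr(rpc.BackendType.TENSORPIPE)  # "BackendType.TENSORPIPE", via _backend_type_repr
    print(rpc.BackendType.__doc__)    # the text assigned from _backend_type_doc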
