Skip to content

Commit 8f51c39

Browse files
mrshenlifacebook-github-bot
authored andcommitted
Improve torch.futures docs (pytorch#40245)
Summary: Pull Request resolved: pytorch#40245 Test Plan: Imported from OSS Differential Revision: D22126892 Pulled By: mrshenli fbshipit-source-id: e7d06b9b20ac8473cc6f0572dd4872096fd366c3
1 parent 13bd599 commit 8f51c39

File tree

1 file changed

+31
-7
lines changed

1 file changed

+31
-7
lines changed

torch/futures/__init__.py

Lines changed: 31 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,9 @@ def then(self, callback):
4949
>>> # The inserted callback will print the return value when
5050
>>> # receiving the response from "worker1"
5151
>>> cb_fut = fut.then(callback)
52-
>>> chain_cb_fut = cb_fut.then(lambda x : print(f"Chained cb done. {x.wait()}"))
52+
>>> chain_cb_fut = cb_fut.then(
53+
>>> lambda x : print(f"Chained cb done. {x.wait()}")
54+
>>> )
5355
>>> fut.set_result(5)
5456
>>>
5557
>>> # Outputs are:
@@ -91,14 +93,34 @@ def set_result(self, result):
9193

9294
def collect_all(futures):
9395
r"""
94-
Collects the Futures into a single combined Future that is completed
95-
when all of the sub-futures are completed.
96+
Collects the provided :class:`~torch.futures.Future` objects into a single
97+
combined :class:`~torch.futures.Future` that is completed when all of the
98+
sub-futures are completed.
9699
97100
Arguments:
98-
futures: a list of Futures
101+
futures (list): a list of :class:`~torch.futures.Future` objects.
99102
100103
Returns:
101-
Returns a Future object to a list of the passed in Futures.
104+
Returns a :class:`~torch.futures.Future` object to a list of the passed
105+
in Futures.
106+
107+
Example::
108+
>>> import torch
109+
>>>
110+
>>> fut0 = torch.futures.Future()
111+
>>> fut1 = torch.futures.Future()
112+
>>>
113+
>>> fut = torch.futures.collect_all([fut0, fut1])
114+
>>>
115+
>>> fut0.set_result(0)
116+
>>> fut1.set_result(1)
117+
>>>
118+
>>> fut_list = fut.wait()
119+
>>> print(f"fut0 result = {fut_list[0].wait()}")
120+
>>> print(f"fut1 result = {fut_list[1].wait()}")
121+
>>> # outputs:
122+
>>> # fut0 result = 0
123+
>>> # fut1 result = 1
102124
"""
103125
return torch._C._collect_all(futures)
104126

@@ -108,9 +130,11 @@ def wait_all(futures):
108130
the list of completed values.
109131
110132
Arguments:
111-
futures: a list of Futures
133+
futures (list): a list of :class:`~torch.futures.Future` object.
112134
113135
Returns:
114-
A list of the completed Future results
136+
A list of the completed :class:`~torch.futures.Future` results. This
137+
method will throw an error if ``wait`` on any
138+
:class:`~torch.futures.Future` throws.
115139
"""
116140
return [fut.wait() for fut in torch._C._collect_all(futures).wait()]

0 commit comments

Comments
 (0)