Apply sharding based on priority & combine DistInfo and ExtraInfo (pytorch#916)

Summary:
After pytorch/pytorch#88424 landed, `apply_sharding` can be invoked per sharding level (distributed or multiprocessing), which gives each `ReadingService` fine-grained control over sharding; a sketch follows this list.
- For `DistributedReadingService`, we only apply sharding at the distributed level.
- For `PrototypeMPReadingService`, we apply distributed sharding in the main process and multiprocessing sharding in the worker processes. Previously, sharding was applied in each worker process based on both distributed and multiprocessing information.
  - `worker_init_fn` no longer needs `DistInfo`, as the `DataPipe` has already been sharded at the distributed level in the main process.
  - `DistInfo` and `ExtraInfo` are combined into a single argument for `worker_reset_fn`, which synchronizes the distributed seed across distributed workers and sets worker-local seeds based on both distributed and multiprocessing information.
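
A minimal sketch of the resulting two-level sharding, assuming the `apply_sharding(datapipe, num_of_instances, instance_id, sharding_group)` signature from pytorch/pytorch#88424; the concrete sizes and ids below are illustrative, not from this commit:

```python
import torch
from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
from torchdata.datapipes.iter import IterableWrapper

datapipe = IterableWrapper(range(16)).sharding_filter()

# Main process: shard across distributed ranks (say world_size=2, rank=0).
torch.utils.data.graph_settings.apply_sharding(datapipe, 2, 0, SHARDING_PRIORITIES.DISTRIBUTED)

# Worker process: shard again across MP workers (say num_workers=4, worker_id=1).
torch.utils.data.graph_settings.apply_sharding(datapipe, 4, 1, SHARDING_PRIORITIES.MULTIPROCESSING)

# The two sharding groups compose, so each (rank, worker) pair sees a
# disjoint 1/8 slice of the 16 elements.
print(list(datapipe))
```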

Pull Request resolved: pytorch#916

Reviewed By: mingyuzh

Differential Revision: D41776719

Pulled By: ejguan

fbshipit-source-id: 6042da09f5e83019d536696237028ea20e67d110
ejguan authored and facebook-github-bot committed Dec 7, 2022
1 parent c87cfb8 commit 14888a3
Showing 5 changed files with 54 additions and 65 deletions.
3 changes: 1 addition & 2 deletions docs/source/dataloader2.rst
@@ -18,8 +18,7 @@ Note:
 - :class:`torchdata.datapipes.map.SequenceWrapper`: ``torch.utils.data.Dataset``
 - :class:`torchdata.datapipes.iter.IterableWrapper`: ``torch.utils.data.IterableDataset``
 
-Both custom ``worker_init_fn`` and ``worker_reset_fn`` require the following three arguments:
-- :class:`torchdata.dataloader2.utils.DistInfo`
+Both custom ``worker_init_fn`` and ``worker_reset_fn`` require the following two arguments:
 - :class:`torchdata.dataloader2.utils.WorkerInfo`
 - ``DataPipe``
10 changes: 7 additions & 3 deletions test/dataloader2/test_dataloader2.py
@@ -12,6 +12,8 @@
 
 import torch
 
+from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
+
 from torchdata.dataloader2 import (
     communication,
     DataLoader2,
@@ -352,13 +354,15 @@ def clean_me(process, req_queue, res_queue):
 
 class PrototypeMultiProcessingReadingServiceTest(TestCase):
     @staticmethod
-    def _worker_init_fn(datapipe, dist_info, worker_info):
+    def _worker_init_fn(datapipe, worker_info):
         datapipe = datapipe.sharding_filter()
-        torch.utils.data.graph_settings.apply_sharding(datapipe, worker_info.num_workers, worker_info.worker_id)
+        torch.utils.data.graph_settings.apply_sharding(
+            datapipe, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING
+        )
         return datapipe
 
     @staticmethod
-    def _worker_reset_fn(datapipe, dist_info, worker_info):
+    def _worker_reset_fn(datapipe, worker_info):
         worker_seed_generator = torch.Generator()
         worker_seed_generator.manual_seed(123)
         torch.utils.data.graph_settings.apply_random_seed(
53 changes: 23 additions & 30 deletions torchdata/dataloader2/reading_service.py
@@ -16,18 +16,13 @@
 import torch.distributed as dist
 
 from torch.utils.data import DataLoader
+from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
 
 from torchdata._constants import default_dl2_worker_join_timeout_in_s, default_timeout_in_s
 from torchdata.dataloader2 import communication
 from torchdata.dataloader2.graph import DataPipe
-from torchdata.dataloader2.utils import (
-    DistInfo,
-    generate_random_scalar_tensor,
-    process_init_fn,
-    process_reset_fn,
-    WorkerInfo,
-)
-from torchdata.dataloader2.utils.worker import _ExtraInfo
+from torchdata.dataloader2.utils import generate_random_scalar_tensor, process_init_fn, process_reset_fn, WorkerInfo
+from torchdata.dataloader2.utils.worker import _DistInfo
 from torchdata.datapipes.iter import FullSync, IterableWrapper, IterDataPipe
 
 
@@ -186,34 +181,35 @@ class PrototypeMultiProcessingReadingService(ReadingServiceInterface):
         main_prefetch_cnt: (int, 10 by default): Number of data will be prefetched
             at the end of the whole pipeline in the main process.
         worker_init_fn: (Callable, optional): Function to be called when each worker
-            process launches with ``DistInfo``, ``WorkerInfo`` and ``DataPipe``
+            process launches with ``WorkerInfo`` and ``DataPipe``
            as the expected arguments.
         worker_reset_fn: (Callable, optional): Function to be called at the beginning
-            of each epoch in each worker process with ``DistInfo``, ``WorkerInfo``
+            of each epoch in each worker process with ``WorkerInfo``
            and ``DataPipe`` as the expected arguments.
     """
     num_workers: int
     multiprocessing_context: Optional[str]
     worker_prefetch_cnt: int
     main_prefetch_cnt: int
-    worker_init_fn: Optional[Callable[[DataPipe, DistInfo, WorkerInfo], DataPipe]]
-    worker_reset_fn: Optional[Callable[[DataPipe, DistInfo, WorkerInfo], DataPipe]]
+    worker_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]]
+    worker_reset_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]]
     processes: List
     datapipes: List
     end_datapipe: Optional[DataPipe]
     _mp: bool
     _pg: Optional[dist.ProcessGroup]
-    _dist_info: DistInfo
+    _world_size: int
+    _rank: int
 
     def __init__(
         self,
         num_workers: int = 0,
         multiprocessing_context: Optional[str] = None,
         worker_prefetch_cnt: int = 10,
         main_prefetch_cnt: int = 10,
-        worker_init_fn: Optional[Callable[[DataPipe, DistInfo, WorkerInfo], DataPipe]] = None,
-        worker_reset_fn: Optional[Callable[[DataPipe, DistInfo, WorkerInfo], DataPipe]] = None,
+        worker_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None,
+        worker_reset_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None,
     ) -> None:
         self.num_workers = num_workers
         if multiprocessing_context is not None:
@@ -231,7 +227,8 @@ def __init__(
         self.end_datapipe = None
         self._mp = num_workers > 0
         self._pg = None
-        self._dist_info = DistInfo(1, 0)
+        self._world_size = 1
+        self._rank = 0
 
     def initialize(self, datapipe: DataPipe) -> DataPipe:
         r"""
@@ -240,14 +237,16 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
         creates subprocesses.
         """
         if dist.is_available() and dist.is_initialized():
-            _world_size = dist.get_world_size()
-            _rank = dist.get_rank()
-            self._dist_info = DistInfo(_world_size, _rank)
+            self._world_size = dist.get_world_size()
+            self._rank = dist.get_rank()
             self._pg = dist.new_group(backend="gloo")
+        torch.utils.data.graph_settings.apply_sharding(
+            datapipe, self._world_size, self._rank, SHARDING_PRIORITIES.DISTRIBUTED
+        )
         if not self._mp:
             # TODO(616): Warn and recommend usage of InProcessReadingService
             worker_info = WorkerInfo(1, 0)
-            process_init_fn(datapipe, self._dist_info, worker_info, self.worker_init_fn)
+            process_init_fn(datapipe, worker_info, self.worker_init_fn)
             self.end_datapipe = datapipe
             return datapipe
 
@@ -256,9 +255,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
 
         for worker_id in range(self.num_workers):
             worker_info = WorkerInfo(self.num_workers, worker_id)
-            call_on_process_init = partial(
-                process_init_fn, dist_info=self._dist_info, worker_info=worker_info, custom_init_fn=self.worker_init_fn
-            )
+            call_on_process_init = partial(process_init_fn, worker_info=worker_info, custom_init_fn=self.worker_init_fn)
             ctx = mp.get_context(self.multiprocessing_context)
             # Process contains a ProtocolServer
             (process, req_queue, res_queue) = communication.eventloop.SpawnProcessForDataPipeline(
@@ -300,10 +297,8 @@ def initialize_iteration(self) -> None:
         else:
             end_datapipe = self.end_datapipe
         # Send the shared seed to subprocesses
-        extra_info = _ExtraInfo(shared_seed_int)
-        call_on_epoch_reset = partial(
-            process_reset_fn, dist_info=self._dist_info, extra_info=extra_info, custom_reset_fn=self.worker_reset_fn
-        )
+        dist_info = _DistInfo(shared_seed_int, self._world_size, self._rank)
+        call_on_epoch_reset = partial(process_reset_fn, dist_info=dist_info, custom_reset_fn=self.worker_reset_fn)
         end_datapipe.reset_epoch(call_on_epoch_reset)
         end_datapipe.reset()
         # In-process (num_workers == 0)
@@ -438,9 +433,7 @@ def initialize(self, datapipe: DataPipe) -> DataPipe:
         self._rank = dist.get_rank()
         self._pg = dist.new_group(backend="gloo", timeout=timedelta(seconds=self._timeout))
         torch.utils.data.graph_settings.apply_sharding(
-            datapipe,
-            self._world_size,
-            self._rank,
+            datapipe, self._world_size, self._rank, SHARDING_PRIORITIES.DISTRIBUTED
         )
         # Only append FullSyncIterDataPipe if it's not presented at the end of the pipeline
         if not isinstance(datapipe, FullSync):
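
For context, a hedged usage sketch of the updated `PrototypeMultiProcessingReadingService` hooks; this is not part of the commit, and the pipeline and function names are hypothetical, but the two-argument hook signature matches the change above:

```python
from torchdata.dataloader2 import DataLoader2, PrototypeMultiProcessingReadingService
from torchdata.datapipes.iter import IterableWrapper

# Hooks now receive only (DataPipe, WorkerInfo): distributed sharding has
# already been applied in the main process by ReadingService.initialize().
def my_worker_reset_fn(datapipe, worker_info):
    print(f"resetting worker {worker_info.worker_id}/{worker_info.num_workers}")
    return datapipe

datapipe = IterableWrapper(range(100)).shuffle().sharding_filter()
rs = PrototypeMultiProcessingReadingService(num_workers=2, worker_reset_fn=my_worker_reset_fn)
dl = DataLoader2(datapipe, reading_service=rs)
for epoch in range(2):
    _ = list(dl)  # worker_reset_fn runs at the start of each epoch
```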
3 changes: 1 addition & 2 deletions torchdata/dataloader2/utils/__init__.py
@@ -6,11 +6,10 @@
 
 
 from torchdata.dataloader2.utils.random import generate_random_int, generate_random_scalar_tensor
-from torchdata.dataloader2.utils.worker import DistInfo, process_init_fn, process_reset_fn, WorkerInfo
+from torchdata.dataloader2.utils.worker import process_init_fn, process_reset_fn, WorkerInfo
 
 
 __all__ = [
-    "DistInfo",
     "WorkerInfo",
     "generate_random_int",
     "generate_random_scalar_tensor",
50 changes: 22 additions & 28 deletions torchdata/dataloader2/utils/worker.py
@@ -11,6 +11,8 @@
 
 import torch
 
+from torch.utils.data.datapipes.iter.grouping import SHARDING_PRIORITIES
+
 from torchdata.dataloader2.graph import DataPipe
 from torchdata.dataloader2.utils import generate_random_int
 from torchdata.datapipes.iter import IterDataPipe
@@ -24,19 +26,6 @@
 HAS_NUMPY = False
 
 
-@dataclass(frozen=True)
-class DistInfo:
-    r"""
-    Message class for keeping track of distributed information.
-    Args:
-        world_size (int): Total number of distributed nodes
-        rank (int): Distributed rank for the current distributed node
-    """
-    world_size: int = 1
-    rank: int = 0
-
-
 @dataclass(frozen=True)
 class WorkerInfo:
     r"""
@@ -51,39 +40,44 @@ class WorkerInfo:
 
 
 @dataclass(frozen=True)
-class _ExtraInfo:
+class _DistInfo:
     r"""
-    Message class for extra arguments.
+    Message class for distributed information.
     Args:
         shared_seed: Distributed shared random seed
+        world_size (int): Total number of distributed nodes
+        rank (int): Distributed rank for the current distributed node
     """
     shared_seed: int
+    world_size: int = 1
+    rank: int = 0
 
 
 def process_init_fn(
     datapipe: DataPipe,
-    dist_info: DistInfo,
     worker_info: WorkerInfo,
-    custom_init_fn: Optional[Callable[[DataPipe, DistInfo, WorkerInfo], DataPipe]] = None,
+    custom_init_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None,
 ) -> DataPipe:
     r"""
-    Based on the distributed and worker information, shard the ``DataPipe`` graph dynamically.
+    Based on the worker information, shard the ``DataPipe`` graph dynamically.
     """
-    global_worker_id = worker_info.worker_id * dist_info.world_size + dist_info.rank
-    total_num_workers = worker_info.num_workers * dist_info.world_size
-    torch.utils.data.graph_settings.apply_sharding(datapipe, total_num_workers, global_worker_id)
+    torch.utils.data.graph_settings.apply_sharding(
+        datapipe, worker_info.num_workers, worker_info.worker_id, SHARDING_PRIORITIES.MULTIPROCESSING
+    )
 
     if custom_init_fn is not None:
-        datapipe = custom_init_fn(datapipe, dist_info, worker_info)
+        datapipe = custom_init_fn(datapipe, worker_info)
     assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
 
     return datapipe
 
 
 def process_reset_fn(
     datapipe: DataPipe,
-    dist_info: DistInfo,
     worker_info: WorkerInfo,
-    extra_info: _ExtraInfo,
-    custom_reset_fn: Optional[Callable[[DataPipe, DistInfo, WorkerInfo], DataPipe]] = None,
+    dist_info: _DistInfo,
+    custom_reset_fn: Optional[Callable[[DataPipe, WorkerInfo], DataPipe]] = None,
 ) -> DataPipe:
     r"""
     Based on the distributed shared random seed and worker id, this function is used to
@@ -92,14 +86,14 @@ def process_reset_fn(
     """
     # This function will receive worker local copy of datapipe and reset function from ``initialize_iteration``
     worker_seed_generator = torch.Generator()
-    worker_seed_generator.manual_seed(extra_info.shared_seed)
+    worker_seed_generator.manual_seed(dist_info.shared_seed)
     torch.utils.data.graph_settings.apply_random_seed(
         datapipe,
         worker_seed_generator,
     )
     # Set different seeds across distributed workers
     global_worker_id = worker_info.worker_id * dist_info.world_size + dist_info.rank
-    worker_seed_generator.manual_seed(extra_info.shared_seed + global_worker_id)
+    worker_seed_generator.manual_seed(dist_info.shared_seed + global_worker_id)
 
     py_seed = generate_random_int(worker_seed_generator)
     random.seed(py_seed)
@@ -115,7 +109,7 @@ def process_reset_fn(
         numpy.random.seed(np_seed)
 
     if custom_reset_fn is not None:
-        datapipe = custom_reset_fn(datapipe, dist_info, worker_info)
+        datapipe = custom_reset_fn(datapipe, worker_info)
     assert isinstance(datapipe, (IterDataPipe, MapDataPipe))
 
     return datapipe
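
To make the seeding logic in `process_reset_fn` concrete, a small standalone worked example mirroring the diff above (the values are illustrative):

```python
import torch

shared_seed, world_size, rank = 123, 2, 1  # from _DistInfo
num_workers, worker_id = 4, 2              # from WorkerInfo

# Every worker first seeds the graph with the shared seed, so that e.g.
# shuffle orders agree across distributed ranks.
g = torch.Generator()
g.manual_seed(shared_seed)

# Then each (worker, rank) pair derives a distinct worker-local seed for
# Python/NumPy RNGs, using the same formula as process_reset_fn.
global_worker_id = worker_id * world_size + rank  # 2 * 2 + 1 = 5
g.manual_seed(shared_seed + global_worker_id)     # 123 + 5 = 128
```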
