downcast indices in TokenBlockDataset (facebookresearch#1647)
Summary:
### Measurements
TLDR: This saves ~8% CPU RAM when training a tiny model on a medium-sized dataset (11GB on disk).

Results below; the exact command is given in the Command section:

```
+---------------------+----------------+---------+--------+
| fname               |   cpu_mem_used |     wps |    ppl |
+=====================+================+=========+========+
| branch_nw8_2gpu.log |          25.41 | 54721   | 429.1  |
+---------------------+----------------+---------+--------+
| master_nw8_2gpu.log |          27.53 | 52833.1 | 429.1  |
+---------------------+----------------+---------+--------+
```

### Command

```
base_cmd () {
  dd=$1
  shift
  fairseq-train --fp16 $dd \
            --task language_modeling \
            --arch transformer_lm_gpt2_tiny \
            --sample-break-mode complete --tokens-per-sample 512 \
            --optimizer adam --clip-norm 0.0 --lr 0.0005 \
            --batch-size 1 \
            --max-update 200 --max-epoch 1 \
            --log-format simple --log-interval 100 \
            --restore-file x.pt --no-save \
            --skip-invalid-size-inputs-valid-test --disable-validation $@
}
CUDA_VISIBLE_DEVICES=0,1 base_cmd /private/home/sshleifer/data-bin/stories_mmap --num-workers 8
```
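
The saving comes from holding the dataset's index arrays in the smallest unsigned integer dtype that can represent their largest value, instead of the default 64-bit integers. A minimal sketch of the arithmetic (not part of the commit; the array shape is made up, not measured from the run above):

```python
import numpy as np

# Hypothetical number of blocks for a medium-sized dataset.
n_blocks = 20_000_000

# Default representation: (start, end) offsets stored as 64-bit integers.
slice_indices_64 = np.zeros((n_blocks, 2), dtype=np.int64)

# Downcast: if the largest offset fits in uint32, store that instead.
slice_indices_32 = slice_indices_64.astype(np.uint32)

print(f"{slice_indices_64.nbytes / 2**20:.0f} MiB")  # ~305 MiB
print(f"{slice_indices_32.nbytes / 2**20:.0f} MiB")  # ~153 MiB, half the RAM for this array
```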

Pull Request resolved: fairinternal/fairseq-py#1647

Reviewed By: myleott

Differential Revision: D26628861

Pulled By: sshleifer

fbshipit-source-id: 142afe0358d1c4cae448828ba811b211406509d7
sshleifer authored and facebook-github-bot committed Feb 24, 2021
1 parent c3d2bee commit 55e48f1
Showing 2 changed files with 32 additions and 16 deletions.
37 changes: 25 additions & 12 deletions fairseq/data/indexed_dataset.py

```diff
@@ -15,12 +15,21 @@
 
 from . import FairseqDataset
 
+from typing import Union
+
 
-def __best_fitting_dtype(vocab_size=None):
-    if vocab_size is not None and vocab_size < 65500:
+def best_fitting_uint_dtype(
+    max_int_to_represent,
+) -> Union[np.uint16, np.uint32, np.uint64]:
+
+    if max_int_to_represent is None:
+        return np.uint32  # Safe guess
+    elif max_int_to_represent < 65500:
         return np.uint16
+    elif max_int_to_represent < 4294967295:
+        return np.uint32
     else:
-        return np.int32
+        return np.uint64
 
 
 def get_available_dataset_impl():
@@ -48,7 +57,7 @@ def infer_dataset_impl(path):
 def make_builder(out_file, impl, vocab_size=None):
     if impl == "mmap":
         return MMapIndexedDatasetBuilder(
-            out_file, dtype=__best_fitting_dtype(vocab_size)
+            out_file, dtype=best_fitting_uint_dtype(vocab_size)
         )
     elif impl == "fasta":
         raise NotImplementedError
@@ -92,7 +101,7 @@ def write_longs(f, a):
     f.write(np.array(a, dtype=np.int64))
 
 
-dtypes = {
+_code_to_dtype = {
     1: np.uint8,
     2: np.int8,
     3: np.int16,
@@ -101,12 +110,14 @@ def write_longs(f, a):
     6: np.float,
     7: np.double,
     8: np.uint16,
+    9: np.uint32,
+    10: np.uint64,
 }
 
 
-def code(dtype):
-    for k in dtypes.keys():
-        if dtypes[k] == dtype:
+def _dtype_header_code(dtype) -> int:
+    for k in _code_to_dtype.keys():
+        if _code_to_dtype[k] == dtype:
             return k
     raise ValueError(dtype)
 
@@ -141,7 +152,7 @@ def read_index(self, path):
             version = f.read(8)
             assert struct.unpack("<Q", version) == (1,)
             code, self.element_size = struct.unpack("<QQ", f.read(16))
-            self.dtype = dtypes[code]
+            self.dtype = _code_to_dtype[code]
             self._len, self.s = struct.unpack("<QQ", f.read(16))
             self.dim_offsets = read_longs(f, self._len + 1)
             self.data_offsets = read_longs(f, self._len + 1)
@@ -348,7 +359,9 @@ def finalize(self, index_file):
         index = open(index_file, "wb")
         index.write(b"TNTIDX\x00\x00")
         index.write(struct.pack("<Q", 1))
-        index.write(struct.pack("<QQ", code(self.dtype), self.element_size))
+        index.write(
+            struct.pack("<QQ", _dtype_header_code(self.dtype), self.element_size)
+        )
         index.write(struct.pack("<QQ", len(self.data_offsets) - 1, len(self.sizes)))
         write_longs(index, self.dim_offsets)
         write_longs(index, self.data_offsets)
@@ -374,7 +387,7 @@ def __enter__(self):
 
                     self._file.write(cls._HDR_MAGIC)
                     self._file.write(struct.pack("<Q", 1))
-                    self._file.write(struct.pack("<B", code(dtype)))
+                    self._file.write(struct.pack("<B", _dtype_header_code(dtype)))
 
                     return self
 
@@ -419,7 +432,7 @@ def __init__(self, path):
                 assert (1,) == version
 
                 (dtype_code,) = struct.unpack("<B", stream.read(1))
-                self._dtype = dtypes[dtype_code]
+                self._dtype = _code_to_dtype[dtype_code]
                 self._dtype_size = self._dtype().itemsize
 
                 self._len = struct.unpack("<Q", stream.read(8))[0]
```
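
The binary index header records the token dtype as a small integer code, which is why the two new unsigned dtypes get their own entries (9 and 10) in the code table: writers store the code, readers map it back to a dtype. A minimal sketch of that round trip, using a subset of the table from the diff (simplified; real index files also carry magic bytes, a version field, and offset arrays):

```python
import struct

import numpy as np

# Subset of the dtype-code table from the diff above.
_code_to_dtype = {
    1: np.uint8,
    3: np.int16,
    8: np.uint16,
    9: np.uint32,
    10: np.uint64,
}


def _dtype_header_code(dtype) -> int:
    # Reverse lookup: dtype -> header code, as in the diff above.
    for k, v in _code_to_dtype.items():
        if v == dtype:
            return k
    raise ValueError(dtype)


# Writing an index: record which dtype the data was serialized with.
header = struct.pack("<B", _dtype_header_code(np.uint32))

# Reading it back: recover the dtype from the stored code.
(code,) = struct.unpack("<B", header)
assert _code_to_dtype[code] is np.uint32
```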
11 changes: 7 additions & 4 deletions fairseq/data/token_block_dataset.py

```diff
@@ -6,7 +6,7 @@
 import numpy as np
 import torch
 from fairseq.data import FairseqDataset, plasma_utils
-
+from fairseq.data.indexed_dataset import best_fitting_uint_dtype
 
 class TokenBlockDataset(FairseqDataset):
     """Break a Dataset of tokens into blocks.
@@ -98,9 +98,12 @@ def __init__(
                 sizes,
                 slice_indices,
             )
-        self._slice_indices = plasma_utils.PlasmaArray(slice_indices)
-        self._sizes = plasma_utils.PlasmaArray(self._sizes)
-        self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index)
+        size_dtype = np.uint16 if block_size < 65535 else np.uint32
+        slice_indices_dtype = best_fitting_uint_dtype(slice_indices[-1].max())
+
+        self._slice_indices = plasma_utils.PlasmaArray(slice_indices.astype(slice_indices_dtype))
+        self._sizes = plasma_utils.PlasmaArray(self._sizes.astype(size_dtype))
+        self._block_to_dataset_index = plasma_utils.PlasmaArray(block_to_dataset_index.astype(slice_indices_dtype))
 
     @property
     def slice_indices(self):
```
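
A standalone sketch (not the actual fairseq code) of the effect of the downcast in `TokenBlockDataset.__init__`; `PlasmaArray` and the real block construction are omitted, and the toy arrays below are made up for illustration:

```python
import numpy as np


def best_fitting_uint_dtype(max_int_to_represent):
    # Same thresholds as the helper added to indexed_dataset.py above.
    if max_int_to_represent is None:
        return np.uint32
    elif max_int_to_represent < 65500:
        return np.uint16
    elif max_int_to_represent < 4294967295:
        return np.uint32
    else:
        return np.uint64


block_size = 512

# Hypothetical per-block metadata, built as int64 by default:
# (start, end) token offsets, block lengths, and block -> dataset mapping.
slice_indices = np.array([[0, 512], [512, 900], [900, 1400]], dtype=np.int64)
sizes = np.array([512, 388, 500], dtype=np.int64)
block_to_dataset_index = np.array([[0, 0, 0], [0, 512, 1], [1, 388, 2]], dtype=np.int64)

# Same dtype choices as the diff: block sizes are bounded by block_size,
# offsets are bounded by the last token offset.
size_dtype = np.uint16 if block_size < 65535 else np.uint32
slice_indices_dtype = best_fitting_uint_dtype(slice_indices[-1].max())

slice_indices = slice_indices.astype(slice_indices_dtype)
sizes = sizes.astype(size_dtype)
block_to_dataset_index = block_to_dataset_index.astype(slice_indices_dtype)

print(slice_indices.dtype, sizes.dtype, block_to_dataset_index.dtype)
# -> uint16 uint16 uint16 for this toy example
```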
