fix np.memmap memory leak
np.memmap doesn't free the memory for pages it has already accessed, so once the dataset has been read through, the entire dataset ends up resident in RAM. The simplest workaround, from Stack Overflow, is to just recreate the memmap for each batch. The extra overhead is negligible.

https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
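For context, the behavior is easy to reproduce outside the trainer. A rough sketch of the leak (hypothetical file name 'train.bin'; a multi-gigabyte uint16 token file is assumed):

    import numpy as np

    # One long-lived memmap: every page touched during the pass stays
    # mapped into the process, so resident memory (RES in top/htop)
    # climbs toward the full file size and never falls back.
    data = np.memmap('train.bin', dtype=np.uint16, mode='r')
    total = 0
    for i in range(0, len(data), 1_000_000):
        total += int(data[i:i + 1_000_000].sum())  # touching pages maps them in
    # Dropping the last reference (del data) closes the mapping and lets
    # the OS reclaim those pages.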
kjslag authored Jan 25, 2024
1 parent eba36e8 commit 5156fef
Showing 1 changed file with 6 additions and 3 deletions.
train.py:

@@ -113,10 +113,13 @@

 # poor man's data loader
 data_dir = os.path.join('data', dataset)
-train_data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
-val_data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
 def get_batch(split):
-    data = train_data if split == 'train' else val_data
+    # We recreate np.memmap every batch to avoid a memory leak, as per
+    # https://stackoverflow.com/questions/45132940/numpy-memmap-memory-usage-want-to-iterate-once/61472122#61472122
+    if split == 'train':
+        data = np.memmap(os.path.join(data_dir, 'train.bin'), dtype=np.uint16, mode='r')
+    else:
+        data = np.memmap(os.path.join(data_dir, 'val.bin'), dtype=np.uint16, mode='r')
     ix = torch.randint(len(data) - block_size, (batch_size,))
     x = torch.stack([torch.from_numpy((data[i:i+block_size]).astype(np.int64)) for i in ix])
     y = torch.stack([torch.from_numpy((data[i+1:i+1+block_size]).astype(np.int64)) for i in ix])
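The same pattern in isolation, as a minimal sketch (data_path, block_size, and batch_size are passed as arguments here rather than read from train.py's globals):

    import numpy as np

    def get_batch(data_path, block_size, batch_size):
        # Open the memmap fresh on every call; once the previous memmap
        # object is garbage-collected its mapping is closed, so the pages
        # read for the last batch can be reclaimed and RSS stays flat.
        data = np.memmap(data_path, dtype=np.uint16, mode='r')
        ix = np.random.randint(0, len(data) - block_size, size=batch_size)
        return np.stack([data[i:i + block_size].astype(np.int64) for i in ix])

Reopening the memmap is cheap (essentially an open plus an mmap call, with no data read up front), which is why the per-batch overhead is negligible compared to assembling the batch itself.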
