Commit
Document the data arrangement in word_language_model
Nikolai Morin authored and soumith committed Oct 26, 2017
1 parent 23f8abf commit 7532a61
Showing 2 changed files with 23 additions and 2 deletions.
2 changes: 1 addition & 1 deletion time_sequence_prediction/train.py
@@ -35,7 +35,7 @@ def forward(self, input, future = 0):


 if __name__ == '__main__':
-    # set ramdom seed to 0
+    # set random seed to 0
     np.random.seed(0)
     torch.manual_seed(0)
     # load data and make training set
23 changes: 22 additions & 1 deletion word_language_model/main.py
@@ -1,3 +1,4 @@
+# coding: utf-8
 import argparse
 import time
 import math
@@ -57,6 +58,18 @@
 
 corpus = data.Corpus(args.data)
 
+# Starting from sequential data, batchify arranges the dataset into columns.
+# For instance, with the alphabet as the sequence and batch size 4, we'd get
+# ┌ a g m s ┐
+# │ b h n t │
+# │ c i o u │
+# │ d j p v │
+# │ e k q w │
+# └ f l r x ┘.
+# These columns are treated as independent by the model, which means that the
+# dependence of e.g. 'g' on 'f' cannot be learned, but this arrangement allows
+# more efficient batch processing.
+
 def batchify(data, bsz):
     # Work out how cleanly we can divide the dataset into bsz parts.
     nbatch = data.size(0) // bsz
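
For reference, a minimal standalone sketch of the arrangement the new comment describes. The trimming and reshaping steps below are assumptions for illustration, since the diff truncates the function body; the real batchify in main.py may differ in details (e.g. moving the tensor to CUDA).

import torch

# Hypothetical self-contained version of batchify, consistent with the
# comment above; the narrow/view/transpose steps are assumptions.
def batchify_sketch(data, bsz):
    nbatch = data.size(0) // bsz             # full rows per column
    data = data.narrow(0, 0, nbatch * bsz)   # drop the leftover tail
    # Lay out bsz rows of consecutive tokens, then transpose so that each
    # column holds one contiguous stretch of the original sequence.
    return data.view(bsz, -1).t().contiguous()

tokens = torch.arange(24)          # stand-in for the letters 'a'..'x'
print(batchify_sketch(tokens, 4))  # column 0 is 0..5 ('a'..'f'),
                                   # column 1 is 6..11 ('g'..'l'), etc.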
@@ -95,7 +108,15 @@ def repackage_hidden(h):
     else:
         return tuple(repackage_hidden(v) for v in h)
 
-
+# get_batch subdivides the source data into chunks of length args.bptt.
+# If source is equal to the example output of the batchify function, with
+# a bptt-limit of 2, we'd get the following two Variables for i = 0:
+# ┌ a g m s ┐ ┌ b h n t ┐
+# └ b h n t ┘ └ c i o u ┘
+# Note that despite the name of the function, the subdivision of data is not
+# done along the batch dimension (i.e. dimension 1), since that was handled
+# by the batchify function. The chunks are along dimension 0, corresponding
+# to the seq_len dimension in the LSTM.
 def get_batch(source, i, evaluation=False):
     seq_len = min(args.bptt, len(source) - 1 - i)
     data = Variable(source[i:i+seq_len], volatile=evaluation)
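
And a hypothetical walk-through of that slicing for i = 0, using the batchify_sketch output from the previous example; plain tensors stand in for Variables, for simplicity.

bptt = 2
source = batchify_sketch(torch.arange(24), 4)  # shape (6, 4)

i = 0
seq_len = min(bptt, len(source) - 1 - i)
data = source[i:i+seq_len]        # rows 0..1 -> the 'a g m s / b h n t' block
target = source[i+1:i+1+seq_len]  # rows 1..2 -> the 'b h n t / c i o u' block
# (The training loop presumably flattens the target for the loss; that
# detail is not shown in this diff.)
print(data)
print(target)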
