Update interleave hyperparameters
PiperOrigin-RevId: 265780130
rachellj218 authored and tensorflower-gardener committed Aug 27, 2019
1 parent 0fa5ff2 commit 85956b1
Showing 1 changed file with 5 additions and 1 deletion.
6 changes: 5 additions & 1 deletion official/bert/input_pipeline.py
@@ -94,8 +94,12 @@ def create_pretrain_dataset(file_paths,
   dataset = dataset.shuffle(len(file_paths))

   # In parallel, create tf record dataset for each train files.
+  # cycle_length = 8 means that up to 8 files will be read and deserialized in
+  # parallel. You may want to increase this number if you have a large number of
+  # CPU cores.
   dataset = dataset.interleave(
-      tf.data.TFRecordDataset, cycle_length=tf.data.experimental.AUTOTUNE)
+      tf.data.TFRecordDataset, cycle_length=8,
+      num_parallel_calls=tf.data.experimental.AUTOTUNE)

   decode_fn = lambda record: decode_record(record, name_to_features)
   dataset = dataset.map(
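
To make the effect of `cycle_length` in the hunk above more concrete, here is a minimal pure-Python sketch of round-robin interleaving. It is not the tf.data implementation: `interleave_order` and the sample `files` data are hypothetical names used only for illustration, the model assumes one element is taken per source per turn, and its exact ordering may differ from tf.data in edge cases. It also does not model `num_parallel_calls`, which affects how many sources are read concurrently rather than how many are open in the cycle.

```python
from itertools import islice

_DONE = object()  # sentinel marking "no more pending sources"


def interleave_order(sources, cycle_length):
    """Return elements in a simplified round-robin interleave order.

    Hypothetical helper (not part of tf.data): keeps at most `cycle_length`
    sources open, takes one element from each open source per pass, and
    opens the next pending source when an open one is exhausted.
    """
    pending = iter(sources)
    open_iters = [iter(s) for s in islice(pending, cycle_length)]
    result = []
    while open_iters:
        still_open = []
        for it in open_iters:
            try:
                result.append(next(it))
                still_open.append(it)
            except StopIteration:
                # Fill the freed slot with the next pending source, if any.
                replacement = next(pending, _DONE)
                if replacement is not _DONE:
                    still_open.append(iter(replacement))
        open_iters = still_open
    return result


# Stand-ins for the records of four input files (illustrative data only).
files = [["a0", "a1", "a2"], ["b0", "b1"], ["c0", "c1", "c2"], ["d0"]]

print(interleave_order(files, cycle_length=1))  # one file at a time
print(interleave_order(files, cycle_length=2))  # two files open at once
```

With `cycle_length=1` the sketch reads each file to the end before opening the next. Larger values mix records from several files, which is the behavior the commit fixes at 8 open files instead of leaving it to AUTOTUNE.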
