Skip to content

Commit c6479e7

Browse files
committed
Ensure that shuffle occurs before map
1 parent 6e52c27 commit c6479e7

File tree

3 files changed

+10
-8
lines changed

3 files changed

+10
-8
lines changed

official/resnet/cifar10_main.py

+3-4
Original file line numberDiff line numberDiff line change
@@ -71,8 +71,6 @@
7171
'validation': 10000,
7272
}
7373

74-
_SHUFFLE_BUFFER = 20000
75-
7674

7775
def record_dataset(filenames):
7876
"""Returns an input pipeline Dataset from `filenames`."""
@@ -158,8 +156,9 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
158156

159157
if is_training:
160158
# When choosing shuffle buffer sizes, larger sizes result in better
161-
# randomness, while smaller sizes have better performance.
162-
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
159+
# randomness, while smaller sizes have better performance. Because CIFAR-10
160+
# is a relatively small dataset, we choose to shuffle the full epoch.
161+
dataset = dataset.shuffle(buffer_size=_NUM_IMAGES['train'])
163162

164163
dataset = dataset.map(parse_record)
165164
dataset = dataset.map(

official/resnet/imagenet_main.py

+4-3
Original file line numberDiff line numberDiff line change
@@ -142,14 +142,15 @@ def input_fn(is_training, data_dir, batch_size, num_epochs=1):
142142

143143
dataset = dataset.flat_map(tf.data.TFRecordDataset)
144144

145-
dataset = dataset.map(lambda value: dataset_parser(value, is_training),
146-
num_parallel_calls=5).prefetch(batch_size)
147-
148145
if is_training:
149146
# When choosing shuffle buffer sizes, larger sizes result in better
150147
# randomness, while smaller sizes have better performance.
151148
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
152149

150+
dataset = dataset.map(lambda value: dataset_parser(value, is_training),
151+
num_parallel_calls=5)
152+
dataset = dataset.prefetch(batch_size)
153+
153154
# We call repeat after shuffling, rather than before, to prevent separate
154155
# epochs from blending together.
155156
dataset = dataset.repeat(num_epochs)

official/wide_deep/wide_deep.py

+3-1
Original file line numberDiff line numberDiff line change
@@ -179,11 +179,12 @@ def parse_csv(value):
179179

180180
# Extract lines from input files using the Dataset API.
181181
dataset = tf.data.TextLineDataset(data_file)
182-
dataset = dataset.map(parse_csv, num_parallel_calls=5)
183182

184183
if shuffle:
185184
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
186185

186+
dataset = dataset.map(parse_csv, num_parallel_calls=5)
187+
187188
# We call repeat after shuffling, rather than before, to prevent separate
188189
# epochs from blending together.
189190
dataset = dataset.repeat(num_epochs)
@@ -193,6 +194,7 @@ def parse_csv(value):
193194
features, labels = iterator.get_next()
194195
return features, labels
195196

197+
196198
def main(unused_argv):
197199
# Clean up the model directory if present
198200
shutil.rmtree(FLAGS.model_dir, ignore_errors=True)

0 commit comments

Comments (0)