Skip to content

Commit 5f0776a

Browse files
authored
Move dataset.map back to before dataset.shuffle in imagenet_main.py (tensorflow#2731)
1 parent 21b48a8 commit 5f0776a

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

official/resnet/imagenet_main.py

+6-8
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,8 @@ def filenames(is_training, data_dir):
8989
for i in range(128)]
9090

9191

92-
def dataset_parser(value, is_training):
93-
"""Parse an Imagenet record from value."""
92+
def record_parser(value, is_training):
93+
"""Parse an ImageNet record from `value`."""
9494
keys_to_features = {
9595
'image/encoded':
9696
tf.FixedLenFeature((), tf.string, default_value=''),
@@ -134,23 +134,21 @@ def dataset_parser(value, is_training):
134134

135135
def input_fn(is_training, data_dir, batch_size, num_epochs=1):
136136
"""Input function which provides batches for train or eval."""
137-
dataset = tf.data.Dataset.from_tensor_slices(
138-
filenames(is_training, data_dir))
137+
dataset = tf.data.Dataset.from_tensor_slices(filenames(is_training, data_dir))
139138

140139
if is_training:
141140
dataset = dataset.shuffle(buffer_size=_FILE_SHUFFLE_BUFFER)
142141

143142
dataset = dataset.flat_map(tf.data.TFRecordDataset)
143+
dataset = dataset.map(lambda value: record_parser(value, is_training),
144+
num_parallel_calls=5)
145+
dataset = dataset.prefetch(batch_size)
144146

145147
if is_training:
146148
# When choosing shuffle buffer sizes, larger sizes result in better
147149
# randomness, while smaller sizes have better performance.
148150
dataset = dataset.shuffle(buffer_size=_SHUFFLE_BUFFER)
149151

150-
dataset = dataset.map(lambda value: dataset_parser(value, is_training),
151-
num_parallel_calls=5)
152-
dataset = dataset.prefetch(batch_size)
153-
154152
# We call repeat after shuffling, rather than before, to prevent separate
155153
# epochs from blending together.
156154
dataset = dataset.repeat(num_epochs)

0 commit comments

Comments
 (0)