@@ -89,8 +89,8 @@ def filenames(is_training, data_dir):
89
89
for i in range (128 )]
90
90
91
91
92
- def dataset_parser (value , is_training ):
93
- """Parse an Imagenet record from value."""
92
+ def record_parser (value , is_training ):
93
+ """Parse an ImageNet record from ` value` ."""
94
94
keys_to_features = {
95
95
'image/encoded' :
96
96
tf .FixedLenFeature ((), tf .string , default_value = '' ),
@@ -134,23 +134,21 @@ def dataset_parser(value, is_training):
134
134
135
135
def input_fn (is_training , data_dir , batch_size , num_epochs = 1 ):
136
136
"""Input function which provides batches for train or eval."""
137
- dataset = tf .data .Dataset .from_tensor_slices (
138
- filenames (is_training , data_dir ))
137
+ dataset = tf .data .Dataset .from_tensor_slices (filenames (is_training , data_dir ))
139
138
140
139
if is_training :
141
140
dataset = dataset .shuffle (buffer_size = _FILE_SHUFFLE_BUFFER )
142
141
143
142
dataset = dataset .flat_map (tf .data .TFRecordDataset )
143
+ dataset = dataset .map (lambda value : record_parser (value , is_training ),
144
+ num_parallel_calls = 5 )
145
+ dataset = dataset .prefetch (batch_size )
144
146
145
147
if is_training :
146
148
# When choosing shuffle buffer sizes, larger sizes result in better
147
149
# randomness, while smaller sizes have better performance.
148
150
dataset = dataset .shuffle (buffer_size = _SHUFFLE_BUFFER )
149
151
150
- dataset = dataset .map (lambda value : dataset_parser (value , is_training ),
151
- num_parallel_calls = 5 )
152
- dataset = dataset .prefetch (batch_size )
153
-
154
152
# We call repeat after shuffling, rather than before, to prevent separate
155
153
# epochs from blending together.
156
154
dataset = dataset .repeat (num_epochs )
0 commit comments