# coding=utf-8
# Copyright 2022 The Pix2Seq Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Dataset base class."""
import abc
import functools
import operator
from typing import Callable
import ml_collections
import registry
import tensorflow as tf
import tensorflow_datasets as tfds
DatasetRegistry = registry.Registry()
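# Concrete datasets are looked up by name through DatasetRegistry. A minimal
# sketch of how a subclass might be registered (the exact decorator API comes
# from registry.py and is assumed here; 'my_dataset' is a made-up name):
#
#   @DatasetRegistry.register('my_dataset')
#   class MyDataset(TFDSDataset):
#     ...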
def mix_datasets(input_fns, weights):
"""Mix multiple datasets according to weights.
Args:
input_fns: a list of input_fn's. Each input_fn takes in an input_context and
produces a tf.data.Dataset instance.
weights: a list of floats where weights[i] represents the probability to
sample from input_fns[i].
Returns:
a tf.data.Dataset instance.
"""
def input_fn(input_context):
dses = []
for ifn in input_fns:
dses.append(ifn(input_context))
mixed_ds = tf.data.Dataset.sample_from_datasets(dses, weights)
return mixed_ds
return tf.distribute.get_strategy().distribute_datasets_from_function(
input_fn)
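# A rough usage sketch for mix_datasets (the input_fns below are hypothetical
# and not defined in this module); weights give the sampling probabilities:
#
#   mixed = mix_datasets([coco_input_fn, vg_input_fn], weights=[0.7, 0.3])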
class Dataset(abc.ABC):
"""A dataset that handles creating a tf.data.Dataset."""
def __init__(self, config: ml_collections.ConfigDict):
"""Constructs the dataset."""
self.config = config.dataset
self.task_config = config.task
@abc.abstractmethod
def extract(self, example, training):
"""Extracts needed features & annotations into a flat dictionary.
    Note: be conscious about 0 in labels, which should probably be reserved for
    special use (such as padding).
Args:
example: `dict` of raw features.
training: `bool` of training vs eval mode.
Returns:
example: `dict` of relevant features and labels
"""
@abc.abstractmethod
def load_dataset(self, input_context, training):
"""Load tf.data.Dataset from sources such as TFDS or TFRecord files."""
def parse_example(self, example, training):
del training
return example
def filter_example(self, unused_example, unused_training):
return True
def pipeline(self,
process_single_example: Callable[[tf.data.Dataset, int, bool],
tf.data.Dataset],
global_batch_size: int, training: bool):
"""Data pipeline from name to preprocessed examples.
Args:
      process_single_example: a function that takes a dataset of single
        examples and returns a dataset of processed examples.
global_batch_size: global batch size.
training: training vs eval mode.
Returns:
An input_fn which generates a tf.data.Dataset instance.
"""
config = self.config
def input_fn(input_context):
dataset = self.load_dataset(input_context, training)
if config.cache_dataset:
dataset = dataset.cache()
if input_context:
batch_size = input_context.get_per_replica_batch_size(global_batch_size)
        # Sharding is not necessary for TFDS given read_config above.
# dataset = dataset.shard(input_context.num_input_pipelines,
# input_context.input_pipeline_id)
else:
batch_size = global_batch_size
if training:
options = tf.data.Options()
options.deterministic = False
options.experimental_slack = True
dataset = dataset.with_options(options)
buffer_size = config.get('buffer_size', 0)
if buffer_size <= 0:
buffer_size = 10 * batch_size
dataset = dataset.shuffle(buffer_size)
dataset = dataset.repeat()
dataset = dataset.map(
lambda x: self.parse_example(x, training),
num_parallel_calls=tf.data.experimental.AUTOTUNE
).filter(
lambda x: self.filter_example(x, training)
).map(
lambda x: self.extract(x, training),
num_parallel_calls=tf.data.experimental.AUTOTUNE
)
if process_single_example:
dataset = process_single_example(
dataset, config.batch_duplicates, training)
# TODO(b/181662974): Revert this and support non-even batch sizes.
# dataset = dataset.batch(batch_size, drop_remainder=training)
dataset = dataset.padded_batch(batch_size, drop_remainder=True)
if config.batch_duplicates > 1 and training:
dataset = dataset.map(self._flatten_dims,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
dataset = dataset.prefetch(tf.data.experimental.AUTOTUNE)
return dataset
return input_fn
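  # A rough usage sketch for pipeline (variable names are illustrative): the
  # returned input_fn is normally handed to a tf.distribute strategy, e.g.
  #
  #   input_fn = ds.pipeline(process_single_example=None,
  #                          global_batch_size=64, training=True)
  #   dist_ds = tf.distribute.get_strategy(
  #       ).distribute_datasets_from_function(input_fn)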
def _flatten_dims(self, example):
"""Flatten first 2 dims when batch is independently duplicated."""
def flatten_first_2_dims(t):
"""Merge first 2 dims."""
shape_list = t.shape.as_list()
new_bsz = functools.reduce(operator.mul, shape_list[:2])
out_shape = [new_bsz] + shape_list[2:]
return tf.reshape(t, out_shape)
return tf.nest.map_structure(flatten_first_2_dims, example)
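  # For example, with batch_duplicates=2 and a per-replica batch size of 4,
  # tensors batched to shape [4, 2, ...] are reshaped to [8, ...] here.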
@property
@abc.abstractmethod
def num_train_examples(self):
"""Number of training examples."""
@property
@abc.abstractmethod
def num_eval_examples(self):
"""Number of eval examples."""
class TFDSDataset(Dataset):
"""A dataset created from a TFDS dataset.
Each example is a dictionary, but the fields may be different for each
dataset.
Each task would have a list of required fields (e.g. bounding boxes for
object detection). When a dataset is used for a specific task, it should
contain all the fields required by that task.
"""
def __init__(self, config: ml_collections.ConfigDict):
"""Constructs the dataset."""
super().__init__(config)
self.builder = tfds.builder(self.config.tfds_name,
data_dir=self.config.get('data_dir', None))
self.builder.download_and_prepare()
self.allowed_tasks = []
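  # A config sketch for this class (field names follow the self.config reads in
  # this file; the values are illustrative only):
  #
  #   config.dataset.tfds_name = 'coco/2017'
  #   config.dataset.train_split = 'train'
  #   config.dataset.eval_split = 'validation'
  #   config.dataset.cache_dataset = True
  #   config.dataset.batch_duplicates = 1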
def load_dataset(self, input_context, training):
"""Load tf.data.Dataset from TFDS."""
split = self.config.train_split if training else self.config.eval_split
# For TFDS, pass input_context using read_config to make TFDS read
# different parts of the dataset on different workers.
read_config = tfds.ReadConfig(input_context=input_context)
if isinstance(split, list):
dataset = self.builder.as_dataset(
split=split[0], shuffle_files=training, read_config=read_config)
for i in range(1, len(split)):
        dataset = dataset.concatenate(self.builder.as_dataset(
            split=split[i], shuffle_files=training, read_config=read_config))
else:
dataset = self.builder.as_dataset(
split=split, shuffle_files=training, read_config=read_config)
return dataset
@property
def num_train_examples(self):
return self.builder.info.splits[self.config.train_split].num_examples
@property
def num_eval_examples(self):
return self.builder.info.splits[
self.config.eval_split].num_examples if not self.task_config.get(
'unbatch', False) else None
class TFRecordDataset(Dataset):
"""A dataset created from tfrecord files."""
def __init__(self, config: ml_collections.ConfigDict):
"""Constructs the dataset."""
super().__init__(config)
self.dataset_cls = tf.data.TFRecordDataset
def load_dataset(self, input_context, training):
"""Load tf.data.Dataset from TFRecord files."""
if training or self.config.eval_split == 'train':
file_pattern = self.config.train_file_pattern
else:
file_pattern = self.config.val_file_pattern
dataset = tf.data.Dataset.list_files(file_pattern, shuffle=training)
dataset = dataset.interleave(
self.dataset_cls, cycle_length=32, deterministic=not training,
num_parallel_calls=tf.data.experimental.AUTOTUNE)
return dataset
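  # The config fields consumed here might look like (paths are illustrative):
  #
  #   config.dataset.train_file_pattern = '/path/to/train-*.tfrecord'
  #   config.dataset.val_file_pattern = '/path/to/val-*.tfrecord'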
@abc.abstractmethod
def get_feature_map(self, training):
"""Returns feature map(s) for parsing the TFExample.
    Returns a single feature map (a dict) to parse a TFExample.
    Returns a tuple of (context feature map, sequence feature map) to parse a
    TFSequenceExample. Context features are non-sequence features, i.e.
    independent of time/frame. Sequence features have a time/frame dimension.
Args:
training: `bool` of training vs eval mode.
"""
def parse_example(self, example, training):
"""Parse the serialized example into a dictionary of tensors.
Args:
example: the serialized tf.train.Example or tf.train.SequenceExample.
training: `bool` of training vs eval mode.
Returns:
a dictionary of feature name to tensors.
"""
feature_map = self.get_feature_map(training)
if isinstance(feature_map, dict):
example = tf.io.parse_single_example(example, feature_map)
else:
context_features, sequence_features = feature_map
example, sequence = tf.io.parse_single_sequence_example(
example, context_features, sequence_features)
example.update(sequence)
for k in example:
if isinstance(example[k], tf.SparseTensor):
if example[k].dtype == tf.string:
example[k] = tf.sparse.to_dense(example[k], default_value='')
else:
example[k] = tf.sparse.to_dense(example[k], default_value=0)
return example
@property
def num_train_examples(self):
return self.config.train_num_examples
@property
def num_eval_examples(self):
return self.config.eval_num_examples if not self.task_config.get(
'unbatch', False) else None