Merge branch 'support_tfrecord_dataset' into 'master'
support TFRecordDataset

See merge request data/monolith!2141

GitOrigin-RevId: 143074ce7a54b1227110d5c024ff409857ae61c7
zhangpiu authored and monolith committed Nov 7, 2023
1 parent fe6f9b0 commit 412a48a
Showing 1 changed file with 67 additions and 18 deletions.
85 changes: 67 additions & 18 deletions monolith/native_training/data/datasets.py
@@ -78,6 +78,8 @@
                     'string, dataset_input_compression_type')
 flags.DEFINE_bool('dataset_input_use_parquet', None,
                   'bool dataset_input_use_parquet')
+flags.DEFINE_bool('dataset_input_use_tfrecord', None,
+                  'bool dataset_input_use_tfrecord')
 flags.DEFINE_integer('dataset_worker_idx', None, 'int dataset_worker_idx')
 flags.DEFINE_integer('dataset_num_workers', None, 'int dataset_num_workers')
 flags.DEFINE_string('kafka_other_metadata', None,
@@ -124,6 +126,8 @@ def _get_params(name, default=None):
 class DatasetMetaclass(type):
 
   def __call__(cls, *args, **kwargs):
+    logging.info('---args: %s', args)
+    logging.info('---kwargs: %s', kwargs)
     if kwargs.get('topics_or_files', None):
       value = kwargs['topics_or_files']
       if isinstance(value, str):
@@ -221,42 +225,71 @@ def pattern_recurse(pattern_format_list, *args):
 
     if FLAGS.dataset_input_use_parquet is not None:
       kwargs['use_parquet'] = FLAGS.dataset_input_use_parquet
+    if FLAGS.dataset_input_use_tfrecord is not None:
+      kwargs['use_tfrecord'] = FLAGS.dataset_input_use_tfrecord
+      assert not (
+          FLAGS.dataset_input_use_parquet and FLAGS.dataset_input_use_tfrecord
+      ), "It's not allowed to specify dataset_input_use_parquet=True and dataset_input_use_tfrecord=True"
     if kwargs.get('kafka_other_metadata',
                   None) is None and FLAGS.kafka_other_metadata is not None:
       kwargs['kafka_other_metadata'] = FLAGS.kafka_other_metadata
     try:
       # the first param is str, batch to streaming, use kafka params for cmd
-      args = [
+      kafka_args = [
           kwargs.pop('topics', FLAGS.kafka_topics.split(',')),
           kwargs.pop('group_id', FLAGS.kafka_group_id),
           kwargs.pop('servers', FLAGS.kafka_servers)
       ]
-      assert all(x is not None for x in args)
+      assert all(x is not None for x in kafka_args)
       logging.info('use KafkaDataset!')
-      return KafkaDataset(*args, **kwargs)
-    except:
+      return KafkaDataset(*kafka_args, **kwargs)
+    except Exception as e:
+      logging.error(str(e))
       logging.info("it's not streaming training")
 
-    if args is None or len(args) == 0:
-      if 'patterns' in kwargs and 'group_id' not in kwargs and 'servers' not in kwargs:
+    tf_record_args = {
+        'file_name', 'compression_type', 'buffer_size', 'num_parallel_reads'
+    }
+
+    def is_kafka_dataset():
+      # 'topics', 'group_id' and 'servers' are for KafkaDataset
+      return 'topics' in kwargs and 'group_id' in kwargs and 'servers' in kwargs
+
+    if args is None or len(args) == 0:  # all arguments are in kwargs
+      # 'patterns' for DistributedFilePBDataset
+      if 'patterns' in kwargs and not is_kafka_dataset():
         logging.info('use DistributedFilePBDataset!')
         return DistributedFilePBDataset(**kwargs)
-      elif 'topics' in kwargs and 'group_id' in kwargs and 'servers' in kwargs:
+      elif is_kafka_dataset():
        logging.info('use KafkaDataset!')
         return KafkaDataset(**kwargs)
       elif kwargs.get('use_parquet'):
         return ParquetDataset(**kwargs)
+      elif kwargs.get('use_tfrecord'):
+        logging.info('use TFRecordDataset!')
+        invalid_args = list(k for k in kwargs if k not in tf_record_args)
+        for k in invalid_args:
+          kwargs.pop(k)
+        logging.info('---kwargs: %s', kwargs)
+        return TFRecordDatasetWrapper(**kwargs)
       elif 'file_name' in kwargs or len(kwargs) == 0:
         return FilePBDataset(*args, **kwargs)
       else:
         return super(DatasetMetaclass, cls).__call__(*args, **kwargs)
-    elif isinstance(args[0], str):
+    elif isinstance(args[0], str):  # The first arg is a filename
       if kwargs.get('use_parquet'):
         return ParquetDataset(*args, **kwargs)
+      elif kwargs.get('use_tfrecord'):
+        logging.info('use TFRecordDataset!')
+        invalid_args = list(k for k in kwargs if k not in tf_record_args)
+        for k in invalid_args:
+          kwargs.pop(k)
+        logging.info('---kwargs: %s', kwargs)
+        return TFRecordDatasetWrapper(*args, **kwargs)
       else:
         logging.info('use FilePBDataset!')
         return FilePBDataset(*args, **kwargs)
-    elif isinstance(args[0], (list, tuple)):
+    elif isinstance(args[0], (list, tuple)):  # The first arg is a list, never reach here
       if len(args) > 1:
         if isinstance(args[1], str):
           logging.info('use KafkaDataset!')
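
Taken together, the keyword-only path of the dispatch added above tries Kafka streaming first, then routes construction by inspecting kwargs. A hedged summary of the visible branches (the concrete class that carries DatasetMetaclass is outside this diff):

# Kwargs-only dispatch in DatasetMetaclass.__call__, in branch order:
#   'patterns' present, not a Kafka call       -> DistributedFilePBDataset
#   'topics' + 'group_id' + 'servers' present  -> KafkaDataset
#   use_parquet=True                           -> ParquetDataset
#   use_tfrecord=True                          -> TFRecordDatasetWrapper
#                                                 (kwargs outside tf_record_args are dropped first)
#   'file_name' present, or kwargs empty       -> FilePBDataset
#   anything else                              -> regular class construction via super()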
@@ -365,6 +398,20 @@ def element_spec(self):
     return tensor_spec.TensorSpec([], dtypes.string)
 
 
+class TFRecordDatasetWrapper(tf.data.TFRecordDataset):
+
+  def __init__(self,
+               file_name,
+               compression_type=None,
+               buffer_size=None,
+               num_parallel_reads=None,
+               **kwargs):
+    super().__init__(file_name,
+                     compression_type=compression_type,
+                     buffer_size=buffer_size,
+                     num_parallel_reads=num_parallel_reads)
+
+
 class ParquetDataset(dataset_ops.DatasetSource):
 
   def __init__(self,
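
Because TFRecordDatasetWrapper simply forwards to tf.data.TFRecordDataset and swallows any extra keyword arguments via **kwargs, it can be exercised on its own. A minimal sketch, assuming a hypothetical local file part-00000.tfrecord:

import tensorflow as tf

ds = TFRecordDatasetWrapper(
    file_name='part-00000.tfrecord',     # hypothetical path
    compression_type=None,               # e.g. 'GZIP' for compressed shards
    num_parallel_reads=tf.data.AUTOTUNE,
    output_pb_type='example')            # unknown kwargs are silently ignored
for raw_record in ds.take(2):
  print(raw_record.numpy()[:64])         # leading bytes of each serialized record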
@@ -551,6 +598,7 @@ def __init__(
       num_parallel_calls=tf.data.AUTOTUNE,
       deterministic=None,
       use_parquet: bool = False,
+      use_tfrecord: bool = False,
       **kwargs):
     if not patterns:
       patterns = [""]
@@ -566,9 +614,15 @@ def __init__(
                                              False))
     logging.info(f"enable_dynamic_sharding: {enable_dynamic_sharding}")
 
+    assert not (
+        use_parquet and use_tfrecord
+    ), "It's not allowed to specify use_parquet=True and use_tfrecord=True simultaneously!"
     if use_parquet:
       map_func = lambda file_name: ParquetDataset(
           file_name=file_name, output_pb_type=output_pb_type, **kwargs)
+    elif use_tfrecord:
+      map_func = lambda file_name: tf.data.TFRecordDataset(filenames=
+                                                           [file_name])
     else:
       map_func = lambda file_name: FilePBDataset(
           file_name=file_name,
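
The per-file map_func built above is applied to a stream of matched file names by the surrounding pipeline, which this diff elides. A rough sketch of how the use_tfrecord branch would be consumed, assuming an interleave driven by the num_parallel_calls and deterministic parameters from this __init__:

import tensorflow as tf

file_names = tf.data.Dataset.from_tensor_slices(
    ['00.tfrecord', '01.tfrecord'])  # hypothetical shard list
map_func = lambda file_name: tf.data.TFRecordDataset(filenames=[file_name])
ds = file_names.interleave(
    map_func,
    num_parallel_calls=tf.data.AUTOTUNE,  # mirrors the default above
    deterministic=None)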
@@ -783,10 +837,7 @@ def element_spec(self):
     return tensor_spec.TensorSpec([], dtypes.variant)
 
 
-def instance_reweight(self,
-                      action_priority: str,
-                      reweight: str,
-                      **kwargs):
+def instance_reweight(self, action_priority: str, reweight: str, **kwargs):
   value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY)
   assert len(value) == 1
   variant_type = value[0]
@@ -977,10 +1028,7 @@ def split_flow(self,
                variant_type=variant_type)
 
 
-def merge_flow(self,
-               dataset_to_merge,
-               max_queue_size: int = 1024,
-               **kwargs):
+def merge_flow(self, dataset_to_merge, max_queue_size: int = 1024, **kwargs):
   value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY)
   assert len(value) == 1
   variant_type = value[0]
@@ -1243,7 +1291,8 @@ def __init__(self,
     for meta in kafka_other_metadata_list:
       metadata.append(meta)
 
-    tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY, value=output_pb_type)
+    tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY,
+                                   value=output_pb_type)
     resource = kafka_resource_init(
         topics=topics,
         metadata=metadata,
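For completeness, the new dataset_input_use_tfrecord flag mirrors dataset_input_use_parquet and is mutually exclusive with it. A hedged sketch of driving the switch through absl flags (flag names are from this diff; the manual argv parse is a stand-in for the trainer command line):

from absl import flags

# Importing the module registers the dataset flags defined at the top of this diff.
from monolith.native_training.data import datasets  # noqa: F401

FLAGS = flags.FLAGS
FLAGS(['trainer', '--dataset_input_use_tfrecord=true'])  # hypothetical argv
assert FLAGS.dataset_input_use_tfrecord
assert not FLAGS.dataset_input_use_parquet  # setting both True trips the new assert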
