Merge branch 'remove_variant_type' into 'master'
remove variant_type

See merge request data/monolith!2117

GitOrigin-RevId: 75f862b49983eb0209d22f162ad46ee6f21d947c
zhangpiu authored and monolith committed Sep 4, 2023
1 parent fd7b096 commit f871aa6
Showing 1 changed file with 35 additions and 9 deletions.
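
The change removes the explicit `variant_type` argument that callers previously had to thread through `instance_reweight`, `negative_gen`, `split_flow`, `merge_flow`, and `transform`. Instead, each source dataset records its `output_pb_type` in a TF1-style graph collection at construction time, and every downstream transformation reads it back, so the value can no longer drift out of sync with the dataset that produced it. A minimal sketch of the pattern, assuming graph-mode `tf.compat.v1` collections; `make_source_dataset` and `current_variant_type` are illustrative names, not functions from this repository:

import tensorflow as tf

OUTPUT_PB_TYPE_GRAPH_KEY = "monolith_dataset_output_pb_type"

def make_source_dataset(output_pb_type: str):
  # Producer side: stash the pb type in the graph exactly once,
  # when the source dataset is constructed.
  tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY,
                                 value=output_pb_type)

def current_variant_type() -> str:
  # Consumer side: transformations recover the type from the collection
  # instead of taking it as a parameter.
  values = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY)
  assert len(values) == 1, "expected exactly one source dataset per graph"
  variant_type = values[0]
  assert variant_type in {"instance", "example"}
  return variant_type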
monolith/native_training/data/datasets.py: 35 additions & 9 deletions
@@ -83,6 +83,7 @@
 flags.DEFINE_string('kafka_other_metadata', None,
                     'string, kafka_other_metadata')
 POOL_KEY = "TF_ITEMPOOL"
+OUTPUT_PB_TYPE_GRAPH_KEY = "monolith_dataset_output_pb_type"


 class FeaturePruningType(object):
@@ -393,6 +394,8 @@ def __init__(self,

     self._out_type = tf.string if output_pb_type == PbType.PLAINTEXT else tf.variant

+    tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY,
+                                   value=output_pb_type.to_name())
     variant_tensor = pb_datasource_ops.parquet_dataset(
         file_name=file_name,
         output_pb_type=output_pb_type.to_name(),
@@ -507,6 +510,8 @@ def __init__(
     if use_snappy is None:
       use_snappy = False

+    tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY,
+                                   value=output_pb_type.to_name())
     variant_tensor = pb_datasource_ops.pb_dataset(
         file_name=file_name,
         use_snappy=use_snappy,
@@ -781,7 +786,11 @@ def element_spec(self):
   def instance_reweight(self,
                         action_priority: str,
                         reweight: str,
-                        variant_type: str = 'example'):
+                        **kwargs):
+    value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY)
+    assert len(value) == 1
+    variant_type = value[0]
+    assert variant_type in {"instance", "example"}
     return InstanceReweightDataset(self,
                                    action_priority,
                                    reweight,
@@ -879,6 +888,7 @@ def __init__(self,
         'ds_to_merge_{}'.format(i + 1)
         for i in range(len(self._dataset_to_merge))
     ]
+
     variant_tensor = pb_datasource_ops.merge_flow_dataset(
         input_dataset_variant,
         data_flow=data_flow,
@@ -918,7 +928,11 @@ def negative_gen(self,
                    origin_neg_in_pool_proba: float = 1.0,
                    neg_sample_declay_factor: float = 1.0,
                    easy_hard_ratio: float = 0.0,
-                   variant_type: str = 'example'):
+                   **kwargs):
+    value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY)
+    assert len(value) == 1
+    variant_type = value[0]
+    assert variant_type in {"instance", "example"}
     return NegativeGenDataset(
         self,
         neg_num=neg_num,
@@ -951,7 +965,11 @@ def split_flow(self,
                  data_flow: List[str],
                  index: int,
                  max_queue_size: int = 1024,
-                 variant_type: str = 'example'):
+                 **kwargs):
+    value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY)
+    assert len(value) == 1
+    variant_type = value[0]
+    assert variant_type in {"instance", "example"}
     return SplitFlowDataset(self,
                             data_flow=data_flow,
                             index=index,
@@ -962,7 +980,11 @@ def split_flow(self,
   def merge_flow(self,
                  dataset_to_merge,
                  max_queue_size: int = 1024,
-                 variant_type: str = 'example'):
+                 **kwargs):
+    value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY)
+    assert len(value) == 1
+    variant_type = value[0]
+    assert variant_type in {"instance", "example"}
     return MergeFlowDataset(self,
                             dataset_to_merge,
                             max_queue_size=max_queue_size,
@@ -1221,6 +1243,7 @@ def __init__(self,
       for meta in kafka_other_metadata_list:
         metadata.append(meta)

+    tf.compat.v1.add_to_collection(name=OUTPUT_PB_TYPE_GRAPH_KEY, value=output_pb_type)
     resource = kafka_resource_init(
         topics=topics,
         metadata=metadata,
@@ -1295,9 +1318,10 @@ def register_dataset(service, dataset, buffer_size=32):
   if external_state_policy is None:
     external_state_policy = ExternalStatePolicy.WARN
   logging.info('external_state_policy: %s', external_state_policy)
-  dataset = dataset.map(lambda *x: compression_ops.compress(x),
-                        # num_parallel_calls=dataset_ops.AUTOTUNE)
-                        num_parallel_calls=None)
+  dataset = dataset.map(
+      lambda *x: compression_ops.compress(x),
+      # num_parallel_calls=dataset_ops.AUTOTUNE)
+      num_parallel_calls=None)
   logging.info('num_parallel_calls: None')
   # dataset = dataset.prefetch(buffer_size=buffer_size)
   dataset = dataset._apply_options()
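
For context on the hunk above: this is the usual compress-before-register step for the tf.data service, and the edit only reflows the call without changing behavior. Keeping num_parallel_calls=None runs compression sequentially, while the commented-out dataset_ops.AUTOTUNE alternative would parallelize it. A sketch of the same pattern, assuming TF's internal compression_ops module; compress_for_service is an illustrative name:

from tensorflow.python.data.experimental.ops import compression_ops

def compress_for_service(dataset):
  # Compress each element before handing the dataset to the tf.data service.
  # num_parallel_calls=None keeps compression sequential; passing
  # tf.data.AUTOTUNE here would let tf.data parallelize the map.
  return dataset.map(lambda *x: compression_ops.compress(x),
                     num_parallel_calls=None)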
@@ -1521,7 +1545,10 @@ def distribute(self,
     return dataset


-  def transform(self, t: Transform, variant_type: str):
+  def transform(self, t: Transform, **kwargs):
+    value = tf.compat.v1.get_collection(OUTPUT_PB_TYPE_GRAPH_KEY)
+    assert len(value) == 1
+    variant_type = value[0]
     assert variant_type in {"instance", "example"}
     return TransformDataset(self, t, variant_type=variant_type)

@@ -1544,7 +1571,6 @@ class TransformDataset(dataset_ops.UnaryUnchangedStructureDataset):
   def __init__(self, input_dataset, transform: Transform, variant_type: str):
     assert variant_type in {"instance", "example"}
     self._transform = transform
-    self._variant_type = variant_type

     variant_tensor = pb_datasource_ops.transform_dataset(
         input=input_dataset._variant_tensor,
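Net effect at call sites: the keyword argument disappears, and the transforms pick the type up from the graph collection instead. A hedged before/after sketch; the arguments other than the method name are illustrative:

# Before: callers passed variant_type explicitly and had to keep it
# consistent with the dataset's output_pb_type.
ds = ds.negative_gen(neg_num=5, variant_type='example')

# After: variant_type is recovered from OUTPUT_PB_TYPE_GRAPH_KEY; note that
# a leftover variant_type=... keyword is now silently absorbed by **kwargs.
ds = ds.negative_gen(neg_num=5)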
