Skip to content

Commit

Permalink
Standardize parameter key definition of Example Gen
Browse files Browse the repository at this point in the history
TFX users define each module in pipeline through different Key-value pairs.
The raw key string are used in different places. To avoid typo that may break the pipeline, define the key strings as public variables under types, and reuse the variables in other files.

This diff is a cleanup on example gen module.

PiperOrigin-RevId: 359945429
  • Loading branch information
tfx-copybara committed Feb 27, 2021
1 parent 8ddca2a commit 992e79b
Show file tree
Hide file tree
Showing 15 changed files with 211 additions and 147 deletions.
38 changes: 23 additions & 15 deletions tfx/components/example_gen/base_example_gen_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
from tfx.dsl.components.base import base_executor
from tfx.proto import example_gen_pb2
from tfx.types import artifact_utils
from tfx.types import standard_component_specs
from tfx.utils import proto_utils
from tfx_bsl.telemetry import util

Expand Down Expand Up @@ -212,12 +213,14 @@ def GenerateExamplesByBeam(
"""
# Get input split information.
input_config = example_gen_pb2.Input()
proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
input_config)
proto_utils.json_to_proto(
exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
input_config)
# Get output split information.
output_config = example_gen_pb2.Output()
proto_utils.json_to_proto(exec_properties[utils.OUTPUT_CONFIG_KEY],
output_config)
proto_utils.json_to_proto(
exec_properties[standard_component_specs.OUTPUT_CONFIG_KEY],
output_config)
# Get output split names.
split_names = utils.generate_output_split_names(input_config, output_config)
# Make beam_pipeline_args available in exec_properties since certain
Expand Down Expand Up @@ -295,14 +298,16 @@ def Do(
self._log_startup(input_dict, output_dict, exec_properties)

input_config = example_gen_pb2.Input()
proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY],
input_config)
proto_utils.json_to_proto(
exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
input_config)
output_config = example_gen_pb2.Output()
proto_utils.json_to_proto(exec_properties[utils.OUTPUT_CONFIG_KEY],
output_config)
proto_utils.json_to_proto(
exec_properties[standard_component_specs.OUTPUT_CONFIG_KEY],
output_config)

examples_artifact = artifact_utils.get_single_instance(
output_dict[utils.EXAMPLES_KEY])
output_dict[standard_component_specs.EXAMPLES_KEY])
examples_artifact.split_names = artifact_utils.encode_split_names(
utils.generate_output_split_names(input_config, output_config))

Expand All @@ -314,13 +319,16 @@ def Do(
for split_name, example_split in example_splits.items():
(example_split
| 'WriteSplit[{}]'.format(split_name) >> _WriteSplit(
artifact_utils.get_split_uri(output_dict[utils.EXAMPLES_KEY],
split_name)))
artifact_utils.get_split_uri(
output_dict[standard_component_specs.EXAMPLES_KEY],
split_name)))
# pylint: enable=expression-not-assigned, no-value-for-parameter

output_payload_format = exec_properties.get(utils.OUTPUT_DATA_FORMAT_KEY)
output_payload_format = exec_properties.get(
standard_component_specs.OUTPUT_DATA_FORMAT_KEY)
if output_payload_format:
for output_examples_artifact in output_dict[utils.EXAMPLES_KEY]:
examples_utils.set_payload_format(
output_examples_artifact, output_payload_format)
for output_examples_artifact in output_dict[
standard_component_specs.EXAMPLES_KEY]:
examples_utils.set_payload_format(output_examples_artifact,
output_payload_format)
logging.info('Examples generated.')
37 changes: 20 additions & 17 deletions tfx/components/example_gen/base_example_gen_executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,17 +19,17 @@
from __future__ import print_function

import os

import apache_beam as beam
from apache_beam.metrics.metric import MetricsFilter
from apache_beam.runners.direct import direct_runner
import tensorflow as tf

from tfx.components.example_gen import base_example_gen_executor
from tfx.components.example_gen import utils
from tfx.dsl.io import fileio
from tfx.proto import example_gen_pb2
from tfx.types import artifact_utils
from tfx.types import standard_artifacts
from tfx.types import standard_component_specs
from tfx.utils import proto_utils


Expand Down Expand Up @@ -94,7 +94,9 @@ def setUp(self):
# Create output dict.
self._examples = standard_artifacts.Examples()
self._examples.uri = self._output_data_dir
self._output_dict = {utils.EXAMPLES_KEY: [self._examples]}
self._output_dict = {
standard_component_specs.EXAMPLES_KEY: [self._examples]
}

self._train_output_file = os.path.join(self._examples.uri, 'train',
'data_tfrecord-00000-of-00001.gz')
Expand All @@ -103,13 +105,13 @@ def setUp(self):

# Create exec proterties for output splits.
self._exec_properties = {
utils.INPUT_CONFIG_KEY:
standard_component_specs.INPUT_CONFIG_KEY:
proto_utils.proto_to_json(
example_gen_pb2.Input(splits=[
example_gen_pb2.Input.Split(
name='single', pattern='single/*'),
])),
utils.OUTPUT_CONFIG_KEY:
standard_component_specs.OUTPUT_CONFIG_KEY:
proto_utils.proto_to_json(
example_gen_pb2.Output(
split_config=example_gen_pb2.SplitConfig(splits=[
Expand Down Expand Up @@ -141,14 +143,14 @@ def _testDo(self):
def testDoInputSplit(self):
# Create exec proterties for input split.
self._exec_properties = {
utils.INPUT_CONFIG_KEY:
standard_component_specs.INPUT_CONFIG_KEY:
proto_utils.proto_to_json(
example_gen_pb2.Input(splits=[
example_gen_pb2.Input.Split(
name='train', pattern='train/*'),
example_gen_pb2.Input.Split(name='eval', pattern='eval/*')
])),
utils.OUTPUT_CONFIG_KEY:
standard_component_specs.OUTPUT_CONFIG_KEY:
proto_utils.proto_to_json(example_gen_pb2.Output())
}

Expand All @@ -170,16 +172,17 @@ def testDoOutputSplitWithSequenceExample(self):
self._testDo()

def _testFeatureBasedPartition(self, partition_feature_name):
self._exec_properties[utils.OUTPUT_CONFIG_KEY] = proto_utils.proto_to_json(
example_gen_pb2.Output(
split_config=example_gen_pb2.SplitConfig(
splits=[
example_gen_pb2.SplitConfig.Split(
name='train', hash_buckets=2),
example_gen_pb2.SplitConfig.Split(
name='eval', hash_buckets=1)
],
partition_feature_name=partition_feature_name)))
self._exec_properties[
standard_component_specs.OUTPUT_CONFIG_KEY] = proto_utils.proto_to_json(
example_gen_pb2.Output(
split_config=example_gen_pb2.SplitConfig(
splits=[
example_gen_pb2.SplitConfig.Split(
name='train', hash_buckets=2),
example_gen_pb2.SplitConfig.Split(
name='eval', hash_buckets=1)
],
partition_feature_name=partition_feature_name)))

def testFeatureBasedPartition(self):
# Update exec proterties.
Expand Down
63 changes: 40 additions & 23 deletions tfx/components/example_gen/component_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
from tfx.proto import example_gen_pb2
from tfx.proto import range_config_pb2
from tfx.types import standard_artifacts
from tfx.types import standard_component_specs
from tfx.utils import proto_utils

from google.protobuf import any_pb2
Expand Down Expand Up @@ -83,11 +84,16 @@ def testConstructSubclassQueryBased(self):
]))
self.assertEqual({}, example_gen.inputs.get_all())
self.assertEqual(base_driver.BaseDriver, example_gen.driver_class)
self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs['examples'].type_name)
self.assertEqual(example_gen.exec_properties['output_data_format'],
example_gen_pb2.FORMAT_TF_EXAMPLE)
self.assertIsNone(example_gen.exec_properties.get('custom_config'))
self.assertEqual(
standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name)
self.assertEqual(
example_gen.exec_properties[
standard_component_specs.OUTPUT_DATA_FORMAT_KEY],
example_gen_pb2.FORMAT_TF_EXAMPLE)
self.assertIsNone(
example_gen.exec_properties.get(
standard_component_specs.CUSTOM_CONFIG_KEY))

def testConstructSubclassQueryBasedWithInvalidOutputDataFormat(self):
self.assertRaises(
Expand All @@ -101,20 +107,25 @@ def testConstructSubclassQueryBasedWithInvalidOutputDataFormat(self):

def testConstructSubclassFileBased(self):
example_gen = TestFileBasedExampleGenComponent(input_base='path')
self.assertIn('input_base', example_gen.exec_properties)
self.assertIn(standard_component_specs.INPUT_BASE_KEY,
example_gen.exec_properties)
self.assertEqual(driver.Driver, example_gen.driver_class)
self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs['examples'].type_name)
self.assertIsNone(example_gen.exec_properties.get('custom_config'))
self.assertEqual(
standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name)
self.assertIsNone(
example_gen.exec_properties.get(
standard_component_specs.CUSTOM_CONFIG_KEY))

def testConstructCustomExecutor(self):
example_gen = component.FileBasedExampleGen(
input_base='path',
custom_executor_spec=executor_spec.ExecutorClassSpec(
TestExampleGenExecutor))
self.assertEqual(driver.Driver, example_gen.driver_class)
self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs['examples'].type_name)
self.assertEqual(
standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name)

def testConstructWithOutputConfig(self):
output_config = example_gen_pb2.Output(
Expand All @@ -125,12 +136,14 @@ def testConstructWithOutputConfig(self):
]))
example_gen = TestFileBasedExampleGenComponent(
input_base='path', output_config=output_config)
self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs['examples'].type_name)
self.assertEqual(
standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name)

stored_output_config = example_gen_pb2.Output()
proto_utils.json_to_proto(example_gen.exec_properties['output_config'],
stored_output_config)
proto_utils.json_to_proto(
example_gen.exec_properties[standard_component_specs.OUTPUT_CONFIG_KEY],
stored_output_config)
self.assertEqual(output_config, stored_output_config)

def testConstructWithInputConfig(self):
Expand All @@ -141,12 +154,14 @@ def testConstructWithInputConfig(self):
])
example_gen = TestFileBasedExampleGenComponent(
input_base='path', input_config=input_config)
self.assertEqual(standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs['examples'].type_name)
self.assertEqual(
standard_artifacts.Examples.TYPE_NAME,
example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name)

stored_input_config = example_gen_pb2.Input()
proto_utils.json_to_proto(example_gen.exec_properties['input_config'],
stored_input_config)
proto_utils.json_to_proto(
example_gen.exec_properties[standard_component_specs.INPUT_CONFIG_KEY],
stored_input_config)
self.assertEqual(input_config, stored_input_config)

def testConstructWithCustomConfig(self):
Expand All @@ -158,8 +173,9 @@ def testConstructWithCustomConfig(self):
TestExampleGenExecutor))

stored_custom_config = example_gen_pb2.CustomConfig()
proto_utils.json_to_proto(example_gen.exec_properties['custom_config'],
stored_custom_config)
proto_utils.json_to_proto(
example_gen.exec_properties[standard_component_specs.CUSTOM_CONFIG_KEY],
stored_custom_config)
self.assertEqual(custom_config, stored_custom_config)

def testConstructWithStaticRangeConfig(self):
Expand All @@ -172,8 +188,9 @@ def testConstructWithStaticRangeConfig(self):
custom_executor_spec=executor_spec.ExecutorClassSpec(
TestExampleGenExecutor))
stored_range_config = range_config_pb2.RangeConfig()
proto_utils.json_to_proto(example_gen.exec_properties['range_config'],
stored_range_config)
proto_utils.json_to_proto(
example_gen.exec_properties[standard_component_specs.RANGE_CONFIG_KEY],
stored_range_config)
self.assertEqual(range_config, stored_range_config)


Expand Down
5 changes: 2 additions & 3 deletions tfx/components/example_gen/csv_example_gen/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,9 @@
from absl import logging
import apache_beam as beam
import tensorflow as tf

from tfx.components.example_gen import utils
from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor
from tfx.dsl.io import fileio
from tfx.types import standard_component_specs
from tfx.utils import io_utils
from tfx_bsl.coders import csv_decoder

Expand Down Expand Up @@ -122,7 +121,7 @@ def _CsvToExample( # pylint: disable=invalid-name
Raises:
RuntimeError: if split is empty or csv headers are not equal.
"""
input_base_uri = exec_properties[utils.INPUT_BASE_KEY]
input_base_uri = exec_properties[standard_component_specs.INPUT_BASE_KEY]
csv_pattern = os.path.join(input_base_uri, split_pattern)
logging.info('Processing input csv data %s to TFExample.', csv_pattern)

Expand Down
20 changes: 12 additions & 8 deletions tfx/components/example_gen/csv_example_gen/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,16 +19,16 @@
from __future__ import print_function

import os

import apache_beam as beam
from apache_beam.testing import util
import tensorflow as tf

from tfx.components.example_gen import utils
from tfx.components.example_gen.csv_example_gen import executor
from tfx.dsl.io import fileio
from tfx.proto import example_gen_pb2
from tfx.types import artifact_utils
from tfx.types import standard_artifacts
from tfx.types import standard_component_specs
from tfx.utils import proto_utils


Expand All @@ -45,7 +45,9 @@ def testCsvToExample(self):
examples = (
pipeline
| 'ToTFExample' >> executor._CsvToExample(
exec_properties={utils.INPUT_BASE_KEY: self._input_data_dir},
exec_properties={
standard_component_specs.INPUT_BASE_KEY: self._input_data_dir
},
split_pattern='csv/*'))

def check_results(results):
Expand All @@ -61,7 +63,9 @@ def testCsvToExampleWithEmptyColumn(self):
examples = (
pipeline
| 'ToTFExample' >> executor._CsvToExample(
exec_properties={utils.INPUT_BASE_KEY: self._input_data_dir},
exec_properties={
standard_component_specs.INPUT_BASE_KEY: self._input_data_dir
},
split_pattern='csv_empty/*'))

def check_results(results):
Expand All @@ -88,18 +92,18 @@ def testDo(self):
# Create output dict.
examples = standard_artifacts.Examples()
examples.uri = output_data_dir
output_dict = {utils.EXAMPLES_KEY: [examples]}
output_dict = {standard_component_specs.EXAMPLES_KEY: [examples]}

# Create exec proterties.
exec_properties = {
utils.INPUT_BASE_KEY:
standard_component_specs.INPUT_BASE_KEY:
self._input_data_dir,
utils.INPUT_CONFIG_KEY:
standard_component_specs.INPUT_CONFIG_KEY:
proto_utils.proto_to_json(
example_gen_pb2.Input(splits=[
example_gen_pb2.Input.Split(name='csv', pattern='csv/*'),
])),
utils.OUTPUT_CONFIG_KEY:
standard_component_specs.OUTPUT_CONFIG_KEY:
proto_utils.proto_to_json(
example_gen_pb2.Output(
split_config=example_gen_pb2.SplitConfig(splits=[
Expand Down
3 changes: 2 additions & 1 deletion tfx/components/example_gen/custom_executors/avro_executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@

from tfx.components.example_gen import utils
from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor
from tfx.types import standard_component_specs


@beam.ptransform_fn
Expand All @@ -49,7 +50,7 @@ def _AvroToExample( # pylint: disable=invalid-name
Returns:
PCollection of TF examples.
"""
input_base_uri = exec_properties[utils.INPUT_BASE_KEY]
input_base_uri = exec_properties[standard_component_specs.INPUT_BASE_KEY]
avro_pattern = os.path.join(input_base_uri, split_pattern)
logging.info('Processing input avro data %s to TFExample.', avro_pattern)

Expand Down
Loading

0 comments on commit 992e79b

Please sign in to comment.