From 992e79b1232d1ae9f494fcf6217d6424bf1fad36 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Sat, 27 Feb 2021 09:22:46 -0800 Subject: [PATCH] Standardize parameter key definition of Example Gen TFX users define each module in a pipeline through different key-value pairs. The raw key strings are used in different places. To avoid typos that may break the pipeline, define the key strings as public variables under types, and reuse the variables in other files. This diff is a cleanup of the example gen module. PiperOrigin-RevId: 359945429 --- .../example_gen/base_example_gen_executor.py | 38 ++++++---- .../base_example_gen_executor_test.py | 37 +++++----- tfx/components/example_gen/component_test.py | 63 ++++++++++------ .../example_gen/csv_example_gen/executor.py | 5 +- .../csv_example_gen/executor_test.py | 20 +++-- .../custom_executors/avro_executor.py | 3 +- .../custom_executors/avro_executor_test.py | 14 ++-- .../custom_executors/parquet_executor.py | 4 +- .../custom_executors/parquet_executor_test.py | 14 ++-- tfx/components/example_gen/driver.py | 28 ++++--- tfx/components/example_gen/driver_test.py | 73 +++++++++++-------- .../import_example_gen/executor.py | 7 +- .../import_example_gen/executor_test.py | 17 +++-- .../container/kubeflow_v2_entrypoint_utils.py | 4 +- tfx/types/standard_component_specs.py | 31 +++++--- 15 files changed, 211 insertions(+), 147 deletions(-) diff --git a/tfx/components/example_gen/base_example_gen_executor.py b/tfx/components/example_gen/base_example_gen_executor.py index effae4e518..70724e5ab5 100644 --- a/tfx/components/example_gen/base_example_gen_executor.py +++ b/tfx/components/example_gen/base_example_gen_executor.py @@ -34,6 +34,7 @@ from tfx.dsl.components.base import base_executor from tfx.proto import example_gen_pb2 from tfx.types import artifact_utils +from tfx.types import standard_component_specs from tfx.utils import proto_utils from tfx_bsl.telemetry import util @@ -212,12 +213,14 @@ def GenerateExamplesByBeam( """ # Get 
input split information. input_config = example_gen_pb2.Input() - proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY], - input_config) + proto_utils.json_to_proto( + exec_properties[standard_component_specs.INPUT_CONFIG_KEY], + input_config) # Get output split information. output_config = example_gen_pb2.Output() - proto_utils.json_to_proto(exec_properties[utils.OUTPUT_CONFIG_KEY], - output_config) + proto_utils.json_to_proto( + exec_properties[standard_component_specs.OUTPUT_CONFIG_KEY], + output_config) # Get output split names. split_names = utils.generate_output_split_names(input_config, output_config) # Make beam_pipeline_args available in exec_properties since certain @@ -295,14 +298,16 @@ def Do( self._log_startup(input_dict, output_dict, exec_properties) input_config = example_gen_pb2.Input() - proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY], - input_config) + proto_utils.json_to_proto( + exec_properties[standard_component_specs.INPUT_CONFIG_KEY], + input_config) output_config = example_gen_pb2.Output() - proto_utils.json_to_proto(exec_properties[utils.OUTPUT_CONFIG_KEY], - output_config) + proto_utils.json_to_proto( + exec_properties[standard_component_specs.OUTPUT_CONFIG_KEY], + output_config) examples_artifact = artifact_utils.get_single_instance( - output_dict[utils.EXAMPLES_KEY]) + output_dict[standard_component_specs.EXAMPLES_KEY]) examples_artifact.split_names = artifact_utils.encode_split_names( utils.generate_output_split_names(input_config, output_config)) @@ -314,13 +319,16 @@ def Do( for split_name, example_split in example_splits.items(): (example_split | 'WriteSplit[{}]'.format(split_name) >> _WriteSplit( - artifact_utils.get_split_uri(output_dict[utils.EXAMPLES_KEY], - split_name))) + artifact_utils.get_split_uri( + output_dict[standard_component_specs.EXAMPLES_KEY], + split_name))) # pylint: enable=expression-not-assigned, no-value-for-parameter - output_payload_format = 
exec_properties.get(utils.OUTPUT_DATA_FORMAT_KEY) + output_payload_format = exec_properties.get( + standard_component_specs.OUTPUT_DATA_FORMAT_KEY) if output_payload_format: - for output_examples_artifact in output_dict[utils.EXAMPLES_KEY]: - examples_utils.set_payload_format( - output_examples_artifact, output_payload_format) + for output_examples_artifact in output_dict[ + standard_component_specs.EXAMPLES_KEY]: + examples_utils.set_payload_format(output_examples_artifact, + output_payload_format) logging.info('Examples generated.') diff --git a/tfx/components/example_gen/base_example_gen_executor_test.py b/tfx/components/example_gen/base_example_gen_executor_test.py index 01f96d49e9..3c11b553b7 100644 --- a/tfx/components/example_gen/base_example_gen_executor_test.py +++ b/tfx/components/example_gen/base_example_gen_executor_test.py @@ -19,17 +19,17 @@ from __future__ import print_function import os + import apache_beam as beam from apache_beam.metrics.metric import MetricsFilter from apache_beam.runners.direct import direct_runner import tensorflow as tf - from tfx.components.example_gen import base_example_gen_executor -from tfx.components.example_gen import utils from tfx.dsl.io import fileio from tfx.proto import example_gen_pb2 from tfx.types import artifact_utils from tfx.types import standard_artifacts +from tfx.types import standard_component_specs from tfx.utils import proto_utils @@ -94,7 +94,9 @@ def setUp(self): # Create output dict. self._examples = standard_artifacts.Examples() self._examples.uri = self._output_data_dir - self._output_dict = {utils.EXAMPLES_KEY: [self._examples]} + self._output_dict = { + standard_component_specs.EXAMPLES_KEY: [self._examples] + } self._train_output_file = os.path.join(self._examples.uri, 'train', 'data_tfrecord-00000-of-00001.gz') @@ -103,13 +105,13 @@ def setUp(self): # Create exec proterties for output splits. 
self._exec_properties = { - utils.INPUT_CONFIG_KEY: + standard_component_specs.INPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Input(splits=[ example_gen_pb2.Input.Split( name='single', pattern='single/*'), ])), - utils.OUTPUT_CONFIG_KEY: + standard_component_specs.OUTPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Output( split_config=example_gen_pb2.SplitConfig(splits=[ @@ -141,14 +143,14 @@ def _testDo(self): def testDoInputSplit(self): # Create exec proterties for input split. self._exec_properties = { - utils.INPUT_CONFIG_KEY: + standard_component_specs.INPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Input(splits=[ example_gen_pb2.Input.Split( name='train', pattern='train/*'), example_gen_pb2.Input.Split(name='eval', pattern='eval/*') ])), - utils.OUTPUT_CONFIG_KEY: + standard_component_specs.OUTPUT_CONFIG_KEY: proto_utils.proto_to_json(example_gen_pb2.Output()) } @@ -170,16 +172,17 @@ def testDoOutputSplitWithSequenceExample(self): self._testDo() def _testFeatureBasedPartition(self, partition_feature_name): - self._exec_properties[utils.OUTPUT_CONFIG_KEY] = proto_utils.proto_to_json( - example_gen_pb2.Output( - split_config=example_gen_pb2.SplitConfig( - splits=[ - example_gen_pb2.SplitConfig.Split( - name='train', hash_buckets=2), - example_gen_pb2.SplitConfig.Split( - name='eval', hash_buckets=1) - ], - partition_feature_name=partition_feature_name))) + self._exec_properties[ + standard_component_specs.OUTPUT_CONFIG_KEY] = proto_utils.proto_to_json( + example_gen_pb2.Output( + split_config=example_gen_pb2.SplitConfig( + splits=[ + example_gen_pb2.SplitConfig.Split( + name='train', hash_buckets=2), + example_gen_pb2.SplitConfig.Split( + name='eval', hash_buckets=1) + ], + partition_feature_name=partition_feature_name))) def testFeatureBasedPartition(self): # Update exec proterties. 
diff --git a/tfx/components/example_gen/component_test.py b/tfx/components/example_gen/component_test.py index 5bac3407ea..6e5f3b1aea 100644 --- a/tfx/components/example_gen/component_test.py +++ b/tfx/components/example_gen/component_test.py @@ -27,6 +27,7 @@ from tfx.proto import example_gen_pb2 from tfx.proto import range_config_pb2 from tfx.types import standard_artifacts +from tfx.types import standard_component_specs from tfx.utils import proto_utils from google.protobuf import any_pb2 @@ -83,11 +84,16 @@ def testConstructSubclassQueryBased(self): ])) self.assertEqual({}, example_gen.inputs.get_all()) self.assertEqual(base_driver.BaseDriver, example_gen.driver_class) - self.assertEqual(standard_artifacts.Examples.TYPE_NAME, - example_gen.outputs['examples'].type_name) - self.assertEqual(example_gen.exec_properties['output_data_format'], - example_gen_pb2.FORMAT_TF_EXAMPLE) - self.assertIsNone(example_gen.exec_properties.get('custom_config')) + self.assertEqual( + standard_artifacts.Examples.TYPE_NAME, + example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name) + self.assertEqual( + example_gen.exec_properties[ + standard_component_specs.OUTPUT_DATA_FORMAT_KEY], + example_gen_pb2.FORMAT_TF_EXAMPLE) + self.assertIsNone( + example_gen.exec_properties.get( + standard_component_specs.CUSTOM_CONFIG_KEY)) def testConstructSubclassQueryBasedWithInvalidOutputDataFormat(self): self.assertRaises( @@ -101,11 +107,15 @@ def testConstructSubclassQueryBasedWithInvalidOutputDataFormat(self): def testConstructSubclassFileBased(self): example_gen = TestFileBasedExampleGenComponent(input_base='path') - self.assertIn('input_base', example_gen.exec_properties) + self.assertIn(standard_component_specs.INPUT_BASE_KEY, + example_gen.exec_properties) self.assertEqual(driver.Driver, example_gen.driver_class) - self.assertEqual(standard_artifacts.Examples.TYPE_NAME, - example_gen.outputs['examples'].type_name) - 
self.assertIsNone(example_gen.exec_properties.get('custom_config')) + self.assertEqual( + standard_artifacts.Examples.TYPE_NAME, + example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name) + self.assertIsNone( + example_gen.exec_properties.get( + standard_component_specs.CUSTOM_CONFIG_KEY)) def testConstructCustomExecutor(self): example_gen = component.FileBasedExampleGen( @@ -113,8 +123,9 @@ def testConstructCustomExecutor(self): custom_executor_spec=executor_spec.ExecutorClassSpec( TestExampleGenExecutor)) self.assertEqual(driver.Driver, example_gen.driver_class) - self.assertEqual(standard_artifacts.Examples.TYPE_NAME, - example_gen.outputs['examples'].type_name) + self.assertEqual( + standard_artifacts.Examples.TYPE_NAME, + example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name) def testConstructWithOutputConfig(self): output_config = example_gen_pb2.Output( @@ -125,12 +136,14 @@ def testConstructWithOutputConfig(self): ])) example_gen = TestFileBasedExampleGenComponent( input_base='path', output_config=output_config) - self.assertEqual(standard_artifacts.Examples.TYPE_NAME, - example_gen.outputs['examples'].type_name) + self.assertEqual( + standard_artifacts.Examples.TYPE_NAME, + example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name) stored_output_config = example_gen_pb2.Output() - proto_utils.json_to_proto(example_gen.exec_properties['output_config'], - stored_output_config) + proto_utils.json_to_proto( + example_gen.exec_properties[standard_component_specs.OUTPUT_CONFIG_KEY], + stored_output_config) self.assertEqual(output_config, stored_output_config) def testConstructWithInputConfig(self): @@ -141,12 +154,14 @@ def testConstructWithInputConfig(self): ]) example_gen = TestFileBasedExampleGenComponent( input_base='path', input_config=input_config) - self.assertEqual(standard_artifacts.Examples.TYPE_NAME, - example_gen.outputs['examples'].type_name) + self.assertEqual( + standard_artifacts.Examples.TYPE_NAME, + 
example_gen.outputs[standard_component_specs.EXAMPLES_KEY].type_name) stored_input_config = example_gen_pb2.Input() - proto_utils.json_to_proto(example_gen.exec_properties['input_config'], - stored_input_config) + proto_utils.json_to_proto( + example_gen.exec_properties[standard_component_specs.INPUT_CONFIG_KEY], + stored_input_config) self.assertEqual(input_config, stored_input_config) def testConstructWithCustomConfig(self): @@ -158,8 +173,9 @@ def testConstructWithCustomConfig(self): TestExampleGenExecutor)) stored_custom_config = example_gen_pb2.CustomConfig() - proto_utils.json_to_proto(example_gen.exec_properties['custom_config'], - stored_custom_config) + proto_utils.json_to_proto( + example_gen.exec_properties[standard_component_specs.CUSTOM_CONFIG_KEY], + stored_custom_config) self.assertEqual(custom_config, stored_custom_config) def testConstructWithStaticRangeConfig(self): @@ -172,8 +188,9 @@ def testConstructWithStaticRangeConfig(self): custom_executor_spec=executor_spec.ExecutorClassSpec( TestExampleGenExecutor)) stored_range_config = range_config_pb2.RangeConfig() - proto_utils.json_to_proto(example_gen.exec_properties['range_config'], - stored_range_config) + proto_utils.json_to_proto( + example_gen.exec_properties[standard_component_specs.RANGE_CONFIG_KEY], + stored_range_config) self.assertEqual(range_config, stored_range_config) diff --git a/tfx/components/example_gen/csv_example_gen/executor.py b/tfx/components/example_gen/csv_example_gen/executor.py index c5016345c0..06e4405aaa 100644 --- a/tfx/components/example_gen/csv_example_gen/executor.py +++ b/tfx/components/example_gen/csv_example_gen/executor.py @@ -24,10 +24,9 @@ from absl import logging import apache_beam as beam import tensorflow as tf - -from tfx.components.example_gen import utils from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor from tfx.dsl.io import fileio +from tfx.types import standard_component_specs from tfx.utils import io_utils from 
tfx_bsl.coders import csv_decoder @@ -122,7 +121,7 @@ def _CsvToExample( # pylint: disable=invalid-name Raises: RuntimeError: if split is empty or csv headers are not equal. """ - input_base_uri = exec_properties[utils.INPUT_BASE_KEY] + input_base_uri = exec_properties[standard_component_specs.INPUT_BASE_KEY] csv_pattern = os.path.join(input_base_uri, split_pattern) logging.info('Processing input csv data %s to TFExample.', csv_pattern) diff --git a/tfx/components/example_gen/csv_example_gen/executor_test.py b/tfx/components/example_gen/csv_example_gen/executor_test.py index ee79e82907..764b564992 100644 --- a/tfx/components/example_gen/csv_example_gen/executor_test.py +++ b/tfx/components/example_gen/csv_example_gen/executor_test.py @@ -19,16 +19,16 @@ from __future__ import print_function import os + import apache_beam as beam from apache_beam.testing import util import tensorflow as tf - -from tfx.components.example_gen import utils from tfx.components.example_gen.csv_example_gen import executor from tfx.dsl.io import fileio from tfx.proto import example_gen_pb2 from tfx.types import artifact_utils from tfx.types import standard_artifacts +from tfx.types import standard_component_specs from tfx.utils import proto_utils @@ -45,7 +45,9 @@ def testCsvToExample(self): examples = ( pipeline | 'ToTFExample' >> executor._CsvToExample( - exec_properties={utils.INPUT_BASE_KEY: self._input_data_dir}, + exec_properties={ + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir + }, split_pattern='csv/*')) def check_results(results): @@ -61,7 +63,9 @@ def testCsvToExampleWithEmptyColumn(self): examples = ( pipeline | 'ToTFExample' >> executor._CsvToExample( - exec_properties={utils.INPUT_BASE_KEY: self._input_data_dir}, + exec_properties={ + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir + }, split_pattern='csv_empty/*')) def check_results(results): @@ -88,18 +92,18 @@ def testDo(self): # Create output dict. 
examples = standard_artifacts.Examples() examples.uri = output_data_dir - output_dict = {utils.EXAMPLES_KEY: [examples]} + output_dict = {standard_component_specs.EXAMPLES_KEY: [examples]} # Create exec proterties. exec_properties = { - utils.INPUT_BASE_KEY: + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir, - utils.INPUT_CONFIG_KEY: + standard_component_specs.INPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Input(splits=[ example_gen_pb2.Input.Split(name='csv', pattern='csv/*'), ])), - utils.OUTPUT_CONFIG_KEY: + standard_component_specs.OUTPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Output( split_config=example_gen_pb2.SplitConfig(splits=[ diff --git a/tfx/components/example_gen/custom_executors/avro_executor.py b/tfx/components/example_gen/custom_executors/avro_executor.py index d7187807e5..dee47c548c 100644 --- a/tfx/components/example_gen/custom_executors/avro_executor.py +++ b/tfx/components/example_gen/custom_executors/avro_executor.py @@ -27,6 +27,7 @@ from tfx.components.example_gen import utils from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor +from tfx.types import standard_component_specs @beam.ptransform_fn @@ -49,7 +50,7 @@ def _AvroToExample( # pylint: disable=invalid-name Returns: PCollection of TF examples. 
""" - input_base_uri = exec_properties[utils.INPUT_BASE_KEY] + input_base_uri = exec_properties[standard_component_specs.INPUT_BASE_KEY] avro_pattern = os.path.join(input_base_uri, split_pattern) logging.info('Processing input avro data %s to TFExample.', avro_pattern) diff --git a/tfx/components/example_gen/custom_executors/avro_executor_test.py b/tfx/components/example_gen/custom_executors/avro_executor_test.py index 4f41901a92..ac526bfe33 100644 --- a/tfx/components/example_gen/custom_executors/avro_executor_test.py +++ b/tfx/components/example_gen/custom_executors/avro_executor_test.py @@ -23,12 +23,12 @@ import apache_beam as beam from apache_beam.testing import util import tensorflow as tf -from tfx.components.example_gen import utils from tfx.components.example_gen.custom_executors import avro_executor from tfx.dsl.io import fileio from tfx.proto import example_gen_pb2 from tfx.types import artifact_utils from tfx.types import standard_artifacts +from tfx.types import standard_component_specs from tfx.utils import proto_utils @@ -45,7 +45,9 @@ def testAvroToExample(self): examples = ( pipeline | 'ToTFExample' >> avro_executor._AvroToExample( - exec_properties={utils.INPUT_BASE_KEY: self._input_data_dir}, + exec_properties={ + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir + }, split_pattern='avro/*.avro')) def check_result(got): @@ -64,19 +66,19 @@ def testDo(self): # Create output dict. examples = standard_artifacts.Examples() examples.uri = output_data_dir - output_dict = {utils.EXAMPLES_KEY: [examples]} + output_dict = {standard_component_specs.EXAMPLES_KEY: [examples]} # Create exec proterties. 
exec_properties = { - utils.INPUT_BASE_KEY: + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir, - utils.INPUT_CONFIG_KEY: + standard_component_specs.INPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Input(splits=[ example_gen_pb2.Input.Split( name='avro', pattern='avro/*.avro'), ])), - utils.OUTPUT_CONFIG_KEY: + standard_component_specs.OUTPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Output( split_config=example_gen_pb2.SplitConfig(splits=[ diff --git a/tfx/components/example_gen/custom_executors/parquet_executor.py b/tfx/components/example_gen/custom_executors/parquet_executor.py index 0f9fc9d95c..f719d9b906 100644 --- a/tfx/components/example_gen/custom_executors/parquet_executor.py +++ b/tfx/components/example_gen/custom_executors/parquet_executor.py @@ -24,9 +24,9 @@ from absl import logging import apache_beam as beam import tensorflow as tf - from tfx.components.example_gen import utils from tfx.components.example_gen.base_example_gen_executor import BaseExampleGenExecutor +from tfx.types import standard_component_specs @beam.ptransform_fn @@ -49,7 +49,7 @@ def _ParquetToExample( # pylint: disable=invalid-name Returns: PCollection of TF examples. 
""" - input_base_uri = exec_properties[utils.INPUT_BASE_KEY] + input_base_uri = exec_properties[standard_component_specs.INPUT_BASE_KEY] parquet_pattern = os.path.join(input_base_uri, split_pattern) logging.info('Processing input parquet data %s to TFExample.', parquet_pattern) diff --git a/tfx/components/example_gen/custom_executors/parquet_executor_test.py b/tfx/components/example_gen/custom_executors/parquet_executor_test.py index 5a8b6b18f2..619e72ee64 100644 --- a/tfx/components/example_gen/custom_executors/parquet_executor_test.py +++ b/tfx/components/example_gen/custom_executors/parquet_executor_test.py @@ -23,12 +23,12 @@ import apache_beam as beam from apache_beam.testing import util import tensorflow as tf -from tfx.components.example_gen import utils from tfx.components.example_gen.custom_executors import parquet_executor from tfx.dsl.io import fileio from tfx.proto import example_gen_pb2 from tfx.types import artifact_utils from tfx.types import standard_artifacts +from tfx.types import standard_component_specs from tfx.utils import proto_utils @@ -45,7 +45,9 @@ def testParquetToExample(self): examples = ( pipeline | 'ToTFExample' >> parquet_executor._ParquetToExample( - exec_properties={utils.INPUT_BASE_KEY: self._input_data_dir}, + exec_properties={ + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir + }, split_pattern='parquet/*')) def check_result(got): @@ -64,19 +66,19 @@ def testDo(self): # Create output dict. examples = standard_artifacts.Examples() examples.uri = output_data_dir - output_dict = {utils.EXAMPLES_KEY: [examples]} + output_dict = {standard_component_specs.EXAMPLES_KEY: [examples]} # Create exec proterties. 
exec_properties = { - utils.INPUT_BASE_KEY: + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir, - utils.INPUT_CONFIG_KEY: + standard_component_specs.INPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Input(splits=[ example_gen_pb2.Input.Split( name='parquet', pattern='parquet/*'), ])), - utils.OUTPUT_CONFIG_KEY: + standard_component_specs.OUTPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Output( split_config=example_gen_pb2.SplitConfig(splits=[ diff --git a/tfx/components/example_gen/driver.py b/tfx/components/example_gen/driver.py index 0009a57fe6..37bc26d153 100644 --- a/tfx/components/example_gen/driver.py +++ b/tfx/components/example_gen/driver.py @@ -34,6 +34,7 @@ from tfx.proto import example_gen_pb2 from tfx.proto import range_config_pb2 from tfx.proto.orchestration import driver_output_pb2 +from tfx.types import standard_component_specs from tfx.utils import proto_utils from ml_metadata.proto import metadata_store_pb2 @@ -85,14 +86,16 @@ def resolve_exec_properties( del pipeline_info, component_info input_config = example_gen_pb2.Input() - proto_utils.json_to_proto(exec_properties[utils.INPUT_CONFIG_KEY], - input_config) + proto_utils.json_to_proto( + exec_properties[standard_component_specs.INPUT_CONFIG_KEY], + input_config) - input_base = exec_properties[utils.INPUT_BASE_KEY] + input_base = exec_properties[standard_component_specs.INPUT_BASE_KEY] logging.debug('Processing input %s.', input_base) range_config = None - range_config_entry = exec_properties.get(utils.RANGE_CONFIG_KEY) + range_config_entry = exec_properties.get( + standard_component_specs.RANGE_CONFIG_KEY) if range_config_entry: range_config = range_config_pb2.RangeConfig() proto_utils.json_to_proto(range_config_entry, range_config) @@ -111,8 +114,8 @@ def resolve_exec_properties( fingerprint, span, version = utils.calculate_splits_fingerprint_span_and_version( input_base, input_config.splits, range_config) - exec_properties[utils.INPUT_CONFIG_KEY] = 
proto_utils.proto_to_json( - input_config) + exec_properties[standard_component_specs + .INPUT_CONFIG_KEY] = proto_utils.proto_to_json(input_config) exec_properties[utils.SPAN_PROPERTY_NAME] = span exec_properties[utils.VERSION_PROPERTY_NAME] = version exec_properties[utils.FINGERPRINT_PROPERTY_NAME] = fingerprint @@ -131,16 +134,16 @@ def _prepare_output_artifacts( """Overrides BaseDriver._prepare_output_artifacts().""" del input_artifacts - example_artifact = output_dict[utils.EXAMPLES_KEY].type() + example_artifact = output_dict[standard_component_specs.EXAMPLES_KEY].type() base_output_dir = os.path.join(pipeline_info.pipeline_root, component_info.component_id) example_artifact.uri = base_driver._generate_output_uri( # pylint: disable=protected-access - base_output_dir, utils.EXAMPLES_KEY, execution_id) + base_output_dir, standard_component_specs.EXAMPLES_KEY, execution_id) update_output_artifact(exec_properties, example_artifact.mlmd_artifact) base_driver._prepare_output_paths(example_artifact) # pylint: disable=protected-access - return {utils.EXAMPLES_KEY: [example_artifact]} + return {standard_component_specs.EXAMPLES_KEY: [example_artifact]} def run( self, execution_info: portable_data_types.ExecutionInfo @@ -159,8 +162,9 @@ def run( data_types_utils.set_metadata_value(result.exec_properties[k], v) # Populate output_dict - output_example = copy.deepcopy( - execution_info.output_dict[utils.EXAMPLES_KEY][0].mlmd_artifact) + output_example = copy.deepcopy(execution_info.output_dict[ + standard_component_specs.EXAMPLES_KEY][0].mlmd_artifact) update_output_artifact(exec_properties, output_example) - result.output_artifacts[utils.EXAMPLES_KEY].artifacts.append(output_example) + result.output_artifacts[ + standard_component_specs.EXAMPLES_KEY].artifacts.append(output_example) return result diff --git a/tfx/components/example_gen/driver_test.py b/tfx/components/example_gen/driver_test.py index b22188503f..c4e5f36c5c 100644 --- 
a/tfx/components/example_gen/driver_test.py +++ b/tfx/components/example_gen/driver_test.py @@ -32,6 +32,7 @@ from tfx.types import artifact_utils from tfx.types import channel_utils from tfx.types import standard_artifacts +from tfx.types import standard_component_specs from tfx.utils import io_utils from tfx.utils import proto_utils @@ -56,9 +57,9 @@ def testResolveExecProperties(self): # Create exec proterties. self._exec_properties = { - utils.INPUT_BASE_KEY: + standard_component_specs.INPUT_BASE_KEY: self._input_base_path, - utils.INPUT_CONFIG_KEY: + standard_component_specs.INPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Input(splits=[ example_gen_pb2.Input.Split( @@ -68,7 +69,7 @@ def testResolveExecProperties(self): name='s2', pattern='span{SPAN}/version{VERSION}/split2/*') ])), - utils.RANGE_CONFIG_KEY: + standard_component_specs.RANGE_CONFIG_KEY: None, } @@ -117,8 +118,9 @@ def testResolveExecProperties(self): r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*' ) updated_input_config = example_gen_pb2.Input() - proto_utils.json_to_proto(self._exec_properties[utils.INPUT_CONFIG_KEY], - updated_input_config) + proto_utils.json_to_proto( + self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY], + updated_input_config) # Check if latest span is selected. self.assertProtoEquals( @@ -133,27 +135,31 @@ def testResolveExecProperties(self): }""", updated_input_config) # Test driver behavior using RangeConfig with static range. 
- self._exec_properties[utils.INPUT_CONFIG_KEY] = proto_utils.proto_to_json( - example_gen_pb2.Input(splits=[ - example_gen_pb2.Input.Split( - name='s1', pattern='span{SPAN:2}/version{VERSION}/split1/*'), - example_gen_pb2.Input.Split( - name='s2', pattern='span{SPAN:2}/version{VERSION}/split2/*') - ])) + self._exec_properties[ + standard_component_specs.INPUT_CONFIG_KEY] = proto_utils.proto_to_json( + example_gen_pb2.Input(splits=[ + example_gen_pb2.Input.Split( + name='s1', + pattern='span{SPAN:2}/version{VERSION}/split1/*'), + example_gen_pb2.Input.Split( + name='s2', pattern='span{SPAN:2}/version{VERSION}/split2/*') + ])) - self._exec_properties[utils.RANGE_CONFIG_KEY] = proto_utils.proto_to_json( - range_config_pb2.RangeConfig( - static_range=range_config_pb2.StaticRange( - start_span_number=1, end_span_number=2))) + self._exec_properties[ + standard_component_specs.RANGE_CONFIG_KEY] = proto_utils.proto_to_json( + range_config_pb2.RangeConfig( + static_range=range_config_pb2.StaticRange( + start_span_number=1, end_span_number=2))) with self.assertRaisesRegexp( ValueError, 'Start and end span numbers for RangeConfig.static_range'): self._example_gen_driver.resolve_exec_properties(self._exec_properties, None, None) - self._exec_properties[utils.RANGE_CONFIG_KEY] = proto_utils.proto_to_json( - range_config_pb2.RangeConfig( - static_range=range_config_pb2.StaticRange( - start_span_number=1, end_span_number=1))) + self._exec_properties[ + standard_component_specs.RANGE_CONFIG_KEY] = proto_utils.proto_to_json( + range_config_pb2.RangeConfig( + static_range=range_config_pb2.StaticRange( + start_span_number=1, end_span_number=1))) self._example_gen_driver.resolve_exec_properties(self._exec_properties, None, None) self.assertEqual(self._exec_properties[utils.SPAN_PROPERTY_NAME], 1) @@ -163,8 +169,9 @@ def testResolveExecProperties(self): 
r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*' ) updated_input_config = example_gen_pb2.Input() - proto_utils.json_to_proto(self._exec_properties[utils.INPUT_CONFIG_KEY], - updated_input_config) + proto_utils.json_to_proto( + self._exec_properties[standard_component_specs.INPUT_CONFIG_KEY], + updated_input_config) # Check if correct span inside static range is selected. self.assertProtoEquals( """ @@ -179,7 +186,10 @@ def testResolveExecProperties(self): def testPrepareOutputArtifacts(self): examples = standard_artifacts.Examples() - output_dict = {utils.EXAMPLES_KEY: channel_utils.as_channel([examples])} + output_dict = { + standard_component_specs.EXAMPLES_KEY: + channel_utils.as_channel([examples]) + } exec_properties = { utils.SPAN_PROPERTY_NAME: 2, utils.VERSION_PROPERTY_NAME: 1, @@ -196,7 +206,7 @@ def testPrepareOutputArtifacts(self): input_artifacts, output_dict, exec_properties, 1, pipeline_info, component_info) examples = artifact_utils.get_single_instance( - output_artifacts[utils.EXAMPLES_KEY]) + output_artifacts[standard_component_specs.EXAMPLES_KEY]) base_output_dir = os.path.join(self._test_dir, component_info.component_id) expected_uri = base_driver._generate_output_uri( # pylint: disable=protected-access base_output_dir, 'examples', 1) @@ -228,13 +238,13 @@ def testDriverRunFn(self): # Prepare output_dic example.uri = 'my_uri' # Will verify that this uri is not changed. - output_dic = {utils.EXAMPLES_KEY: [example]} + output_dic = {standard_component_specs.EXAMPLES_KEY: [example]} # Prepare output_dic exec_proterties. 
exec_properties = { - utils.INPUT_BASE_KEY: + standard_component_specs.INPUT_BASE_KEY: self._input_base_path, - utils.INPUT_CONFIG_KEY: + standard_component_specs.INPUT_CONFIG_KEY: proto_utils.proto_to_json( example_gen_pb2.Input(splits=[ example_gen_pb2.Input.Split( @@ -251,7 +261,7 @@ def testDriverRunFn(self): self.assertEqual(exec_properties[utils.SPAN_PROPERTY_NAME].int_value, 1) updated_input_config = example_gen_pb2.Input() proto_utils.json_to_proto( - exec_properties[utils.INPUT_CONFIG_KEY].string_value, + exec_properties[standard_component_specs.INPUT_CONFIG_KEY].string_value, updated_input_config) self.assertProtoEquals( """ @@ -268,8 +278,11 @@ def testDriverRunFn(self): r'split:s1,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*\nsplit:s2,num_files:1,total_bytes:9,xor_checksum:.*,sum_checksum:.*' ) # Assert output_artifacts' values - self.assertLen(result.output_artifacts[utils.EXAMPLES_KEY].artifacts, 1) - output_example = result.output_artifacts[utils.EXAMPLES_KEY].artifacts[0] + self.assertLen( + result.output_artifacts[ + standard_component_specs.EXAMPLES_KEY].artifacts, 1) + output_example = result.output_artifacts[ + standard_component_specs.EXAMPLES_KEY].artifacts[0] self.assertEqual(output_example.uri, example.uri) self.assertEqual( output_example.custom_properties[utils.SPAN_PROPERTY_NAME].string_value, diff --git a/tfx/components/example_gen/import_example_gen/executor.py b/tfx/components/example_gen/import_example_gen/executor.py index 7c11fdcc1b..16cfeac707 100644 --- a/tfx/components/example_gen/import_example_gen/executor.py +++ b/tfx/components/example_gen/import_example_gen/executor.py @@ -26,8 +26,8 @@ import tensorflow as tf from tfx.components.example_gen import base_example_gen_executor -from tfx.components.example_gen import utils from tfx.proto import example_gen_pb2 +from tfx.types import standard_component_specs @beam.ptransform_fn @@ -50,7 +50,7 @@ def _ImportSerializedRecord( # pylint: disable=invalid-name Returns: 
PCollection of records (tf.Example, tf.SequenceExample, or bytes). """ - input_base_uri = exec_properties[utils.INPUT_BASE_KEY] + input_base_uri = exec_properties[standard_component_specs.INPUT_BASE_KEY] input_split_pattern = os.path.join(input_base_uri, split_pattern) logging.info('Reading input TFRecord data %s.', input_split_pattern) @@ -88,7 +88,8 @@ def ImportRecord(pipeline: beam.Pipeline, exec_properties: Dict[Text, Any], Returns: PCollection of records (tf.Example, tf.SequenceExample, or bytes). """ - output_payload_format = exec_properties.get(utils.OUTPUT_DATA_FORMAT_KEY) + output_payload_format = exec_properties.get( + standard_component_specs.OUTPUT_DATA_FORMAT_KEY) serialized_records = ( pipeline diff --git a/tfx/components/example_gen/import_example_gen/executor_test.py b/tfx/components/example_gen/import_example_gen/executor_test.py index 5927a6d70e..4b2c760c18 100644 --- a/tfx/components/example_gen/import_example_gen/executor_test.py +++ b/tfx/components/example_gen/import_example_gen/executor_test.py @@ -19,16 +19,17 @@ from __future__ import print_function import os + import apache_beam as beam from apache_beam.testing import util import tensorflow as tf - from tfx.components.example_gen import utils from tfx.components.example_gen.import_example_gen import executor from tfx.dsl.io import fileio from tfx.proto import example_gen_pb2 from tfx.types import artifact_utils from tfx.types import standard_artifacts +from tfx.types import standard_component_specs from tfx.utils import proto_utils @@ -57,7 +58,9 @@ def testImportExample(self): examples = ( pipeline | 'ToSerializedRecord' >> executor._ImportSerializedRecord( - exec_properties={utils.INPUT_BASE_KEY: self._input_data_dir}, + exec_properties={ + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir + }, split_pattern='tfrecord/*') | 'ToTFExample' >> beam.Map(tf.train.Example.FromString)) @@ -71,10 +74,10 @@ def check_result(got): def _testDo(self, payload_format): exec_properties = { 
- utils.INPUT_BASE_KEY: self._input_data_dir, - utils.INPUT_CONFIG_KEY: self._input_config, - utils.OUTPUT_CONFIG_KEY: self._output_config, - utils.OUTPUT_DATA_FORMAT_KEY: payload_format, + standard_component_specs.INPUT_BASE_KEY: self._input_data_dir, + standard_component_specs.INPUT_CONFIG_KEY: self._input_config, + standard_component_specs.OUTPUT_CONFIG_KEY: self._output_config, + standard_component_specs.OUTPUT_DATA_FORMAT_KEY: payload_format, } output_data_dir = os.path.join( @@ -84,7 +87,7 @@ def _testDo(self, payload_format): # Create output dict. self.examples = standard_artifacts.Examples() self.examples.uri = output_data_dir - output_dict = {utils.EXAMPLES_KEY: [self.examples]} + output_dict = {standard_component_specs.EXAMPLES_KEY: [self.examples]} # Run executor. import_example_gen = executor.Executor() diff --git a/tfx/orchestration/kubeflow/v2/container/kubeflow_v2_entrypoint_utils.py b/tfx/orchestration/kubeflow/v2/container/kubeflow_v2_entrypoint_utils.py index 851bbb34ff..7e80c49723 100644 --- a/tfx/orchestration/kubeflow/v2/container/kubeflow_v2_entrypoint_utils.py +++ b/tfx/orchestration/kubeflow/v2/container/kubeflow_v2_entrypoint_utils.py @@ -18,11 +18,11 @@ from absl import logging from tfx.components.evaluator import constants -from tfx.components.example_gen import utils from tfx.orchestration.kubeflow.v2 import compiler_utils from tfx.orchestration.kubeflow.v2.proto import pipeline_pb2 from tfx.types import artifact from tfx.types import artifact_utils +from tfx.types import standard_component_specs from tfx.utils import import_utils import yaml @@ -128,7 +128,7 @@ def parse_execution_properties(exec_properties: Any) -> Dict[str, Any]: for k, v in exec_properties.items(): # TODO(b/159835994): Remove this once pipeline populates INPUT_BASE_KEY if k == _OLD_INPUT_BASE_PROPERTY_NAME: - k = utils.INPUT_BASE_KEY + k = standard_component_specs.INPUT_BASE_KEY # Translate each field from Value pb to plain value. 
result[k] = getattr(v, v.WhichOneof('value')) if result[k] is None: diff --git a/tfx/types/standard_component_specs.py b/tfx/types/standard_component_specs.py index 948344d786..7e2a1daa8b 100644 --- a/tfx/types/standard_component_specs.py +++ b/tfx/types/standard_component_specs.py @@ -40,6 +40,7 @@ SCHEMA_KEY = 'schema' EXAMPLES_KEY = 'examples' MODEL_KEY = 'model' +CUSTOM_CONFIG_KEY = 'custom_config' BLESSING_KEY = 'blessing' TRAIN_ARGS_KEY = 'train_args' CUSTOM_CONFIG_KEY = 'custom_config' @@ -75,6 +76,12 @@ OUTPUT_EXAMPLES_KEY = 'output_examples' # Key for schema_gen INFER_FEATURE_SHAPE_KEY = 'infer_feature_shape' +# Key for example_gen +INPUT_BASE_KEY = 'input_base' +INPUT_CONFIG_KEY = 'input_config' +OUTPUT_CONFIG_KEY = 'output_config' +OUTPUT_DATA_FORMAT_KEY = 'output_data_format' +RANGE_CONFIG_KEY = 'range_config' # Key for pusher PUSH_DESTINATION_KEY = 'push_destination' INFRA_BLESSING_KEY = 'infra_blessing' @@ -182,22 +189,22 @@ class FileBasedExampleGenSpec(ComponentSpec): """File-based ExampleGen component spec.""" PARAMETERS = { - 'input_base': + INPUT_BASE_KEY: ExecutionParameter(type=(str, Text)), - 'input_config': + INPUT_CONFIG_KEY: ExecutionParameter(type=example_gen_pb2.Input), - 'output_config': + OUTPUT_CONFIG_KEY: ExecutionParameter(type=example_gen_pb2.Output), - 'output_data_format': + OUTPUT_DATA_FORMAT_KEY: ExecutionParameter(type=int), # example_gen_pb2.PayloadType enum. 
- 'custom_config': + CUSTOM_CONFIG_KEY: ExecutionParameter(type=example_gen_pb2.CustomConfig, optional=True), - 'range_config': + RANGE_CONFIG_KEY: ExecutionParameter(type=range_config_pb2.RangeConfig, optional=True), } INPUTS = {} OUTPUTS = { - 'examples': ChannelParameter(type=standard_artifacts.Examples), + EXAMPLES_KEY: ChannelParameter(type=standard_artifacts.Examples), } @@ -205,18 +212,18 @@ class QueryBasedExampleGenSpec(ComponentSpec): """Query-based ExampleGen component spec.""" PARAMETERS = { - 'input_config': + INPUT_CONFIG_KEY: ExecutionParameter(type=example_gen_pb2.Input), - 'output_config': + OUTPUT_CONFIG_KEY: ExecutionParameter(type=example_gen_pb2.Output), - 'output_data_format': + OUTPUT_DATA_FORMAT_KEY: ExecutionParameter(type=int), # example_gen_pb2.PayloadType enum. - 'custom_config': + CUSTOM_CONFIG_KEY: ExecutionParameter(type=example_gen_pb2.CustomConfig, optional=True), } INPUTS = {} OUTPUTS = { - 'examples': ChannelParameter(type=standard_artifacts.Examples), + EXAMPLES_KEY: ChannelParameter(type=standard_artifacts.Examples), }