From 7f1ab7f91b0ce07844e4547b534ce22680324d47 Mon Sep 17 00:00:00 2001 From: tfx-team Date: Thu, 5 Mar 2020 11:54:54 -0800 Subject: [PATCH] Change model_to_estimator iris example to use generic trainer PiperOrigin-RevId: 299156767 --- RELEASE.md | 1 + .../taxi_utils_native_keras.py | 1 - tfx/examples/iris/iris_pipeline_beam.py | 5 +++ tfx/examples/iris/iris_utils.py | 35 +++++++++++++++++++ tfx/examples/iris/iris_utils_native_keras.py | 1 - 5 files changed, 41 insertions(+), 2 deletions(-) diff --git a/RELEASE.md b/RELEASE.md index 17c0e23b28..ca52034d2f 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -33,6 +33,7 @@ * Added a Chicago Taxi example with native Keras. * Updated TFLite converter to work with TF2. * Enabled filtering by artifact producer and output key in ResolverNode. +* Changed Iris model_to_estimator e2e example to use generic Trainer. ## Bug fixes and other changes * Added --skaffold_cmd flag when updating a pipeline for kubeflow in CLI. diff --git a/tfx/examples/chicago_taxi_pipeline/taxi_utils_native_keras.py b/tfx/examples/chicago_taxi_pipeline/taxi_utils_native_keras.py index e61ea996e9..96e7fcd408 100644 --- a/tfx/examples/chicago_taxi_pipeline/taxi_utils_native_keras.py +++ b/tfx/examples/chicago_taxi_pipeline/taxi_utils_native_keras.py @@ -307,7 +307,6 @@ def run_fn(fn_args: TrainerFnArgs): train_dataset = _input_fn(fn_args.train_files, tf_transform_output, 40) eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, 40) - # If no GPUs are found, CPU is used. 
mirrored_strategy = tf.distribute.MirroredStrategy() with mirrored_strategy.scope(): model = _build_keras_model( diff --git a/tfx/examples/iris/iris_pipeline_beam.py b/tfx/examples/iris/iris_pipeline_beam.py index 6644a146c1..b844aa4234 100644 --- a/tfx/examples/iris/iris_pipeline_beam.py +++ b/tfx/examples/iris/iris_pipeline_beam.py @@ -31,6 +31,8 @@ from tfx.components import SchemaGen from tfx.components import StatisticsGen from tfx.components import Trainer +from tfx.components.base import executor_spec +from tfx.components.trainer.executor import GenericExecutor from tfx.orchestration import metadata from tfx.orchestration import pipeline from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner @@ -86,6 +88,9 @@ def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text, # Uses user-provided Python function that implements a model using TF-Learn. trainer = Trainer( module_file=module_file, + # GenericExecutor uses `run_fn`, while default estimator based executor + # uses `trainer_fn` instead. + custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor), examples=example_gen.outputs['examples'], schema=infer_schema.outputs['schema'], train_args=trainer_pb2.TrainArgs(num_steps=10000), diff --git a/tfx/examples/iris/iris_utils.py b/tfx/examples/iris/iris_utils.py index d4d92c9baf..d9fd6f775b 100644 --- a/tfx/examples/iris/iris_utils.py +++ b/tfx/examples/iris/iris_utils.py @@ -27,6 +27,10 @@ import tensorflow_model_analysis as tfma from tensorflow_transform.tf_metadata import schema_utils +from tensorflow_metadata.proto.v0 import schema_pb2 +from tfx.components.trainer import executor +from tfx.utils import io_utils + _FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width'] _LABEL_KEY = 'variety' @@ -188,3 +192,34 @@ def trainer_fn(trainer_fn_args, schema): 'eval_spec': eval_spec, 'eval_input_receiver_fn': eval_receiver_fn } + + +# TFX generic trainer will call this function instead of trainer_fn. 
+def run_fn(fn_args: executor.TrainerFnArgs): + """Train the model based on given args. + + Args: + fn_args: Holds args used to train the model as name/value pairs. + """ + schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema()) + + training_spec = trainer_fn(fn_args, schema) + + # Train the model + absl.logging.info('Training model.') + tf.estimator.train_and_evaluate(training_spec['estimator'], + training_spec['train_spec'], + training_spec['eval_spec']) + absl.logging.info('Training complete. Model written to %s', + fn_args.serving_model_dir) + + # Export an eval savedmodel for TFMA + # NOTE: When trained in distributed training cluster, eval_savedmodel must be + # exported only by the chief worker (check TF_CONFIG). + absl.logging.info('Exporting eval_savedmodel for TFMA.') + tfma.export.export_eval_savedmodel( + estimator=training_spec['estimator'], + export_dir_base=fn_args.eval_model_dir, + eval_input_receiver_fn=training_spec['eval_input_receiver_fn']) + + absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir) diff --git a/tfx/examples/iris/iris_utils_native_keras.py b/tfx/examples/iris/iris_utils_native_keras.py index 1b763e78da..480e097b77 100644 --- a/tfx/examples/iris/iris_utils_native_keras.py +++ b/tfx/examples/iris/iris_utils_native_keras.py @@ -150,7 +150,6 @@ def run_fn(fn_args: TrainerFnArgs): train_dataset = _input_fn(fn_args.train_files, tf_transform_output, 40) eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, 40) - # If no GPUs are found, CPU is used. mirrored_strategy = tf.distribute.MirroredStrategy() with mirrored_strategy.scope(): model = _build_keras_model()