Skip to content

Commit

Permalink
Change model_to_estimator iris example to use generic trainer
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 299156767
  • Loading branch information
tfx-copybara authored and tensorflow-extended-team committed Mar 5, 2020
1 parent 7152824 commit 7f1ab7f
Show file tree
Hide file tree
Showing 5 changed files with 41 additions and 2 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@
* Added a Chicago Taxi example with native Keras.
* Updated TFLite converter to work with TF2.
* Enabled filtering by artifact producer and output key in ResolverNode.
* Changed Iris model_to_estimator e2e example to use generic Trainer.

## Bug fixes and other changes
* Added --skaffold_cmd flag when updating a pipeline for kubeflow in CLI.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -307,7 +307,6 @@ def run_fn(fn_args: TrainerFnArgs):
train_dataset = _input_fn(fn_args.train_files, tf_transform_output, 40)
eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, 40)

# If no GPUs are found, CPU is used.
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
model = _build_keras_model(
Expand Down
5 changes: 5 additions & 0 deletions tfx/examples/iris/iris_pipeline_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
Expand Down Expand Up @@ -86,6 +88,9 @@ def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
# Uses user-provided Python function that implements a model using TF-Learn.
trainer = Trainer(
module_file=module_file,
# GenericExecutor uses `run_fn`, while default estimator based executor
# uses `trainer_fn` instead.
custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
Expand Down
35 changes: 35 additions & 0 deletions tfx/examples/iris/iris_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,10 @@
import tensorflow_model_analysis as tfma
from tensorflow_transform.tf_metadata import schema_utils

from tensorflow_metadata.proto.v0 import schema_pb2
from tfx.components.trainer import executor
from tfx.utils import io_utils

_FEATURE_KEYS = ['sepal_length', 'sepal_width', 'petal_length', 'petal_width']
_LABEL_KEY = 'variety'

Expand Down Expand Up @@ -188,3 +192,34 @@ def trainer_fn(trainer_fn_args, schema):
'eval_spec': eval_spec,
'eval_input_receiver_fn': eval_receiver_fn
}


# TFX generic trainer will call this function instead of train_fn.
def run_fn(fn_args: executor.TrainerFnArgs):
"""Train the model based on given args.
Args:
fn_args: Holds args used to train the model as name/value pairs.
"""
schema = io_utils.parse_pbtxt_file(fn_args.schema_file, schema_pb2.Schema())

training_spec = trainer_fn(fn_args, schema)

# Train the model
absl.logging.info('Training model.')
tf.estimator.train_and_evaluate(training_spec['estimator'],
training_spec['train_spec'],
training_spec['eval_spec'])
absl.logging.info('Training complete. Model written to %s',
fn_args.serving_model_dir)

# Export an eval savedmodel for TFMA
# NOTE: When trained in distributed training cluster, eval_savedmodel must be
# exported only by the chief worker (check TF_CONFIG).
absl.logging.info('Exporting eval_savedmodel for TFMA.')
tfma.export.export_eval_savedmodel(
estimator=training_spec['estimator'],
export_dir_base=fn_args.eval_model_dir,
eval_input_receiver_fn=training_spec['eval_input_receiver_fn'])

absl.logging.info('Exported eval_savedmodel to %s.', fn_args.eval_model_dir)
1 change: 0 additions & 1 deletion tfx/examples/iris/iris_utils_native_keras.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,6 @@ def run_fn(fn_args: TrainerFnArgs):
train_dataset = _input_fn(fn_args.train_files, tf_transform_output, 40)
eval_dataset = _input_fn(fn_args.eval_files, tf_transform_output, 40)

# If no GPUs are found, CPU is used.
mirrored_strategy = tf.distribute.MirroredStrategy()
with mirrored_strategy.scope():
model = _build_keras_model()
Expand Down

0 comments on commit 7f1ab7f

Please sign in to comment.