Skip to content

Commit

Permalink
Created a chicago taxi example with native keras
Browse files Browse the repository at this point in the history
Also changed input_fn in taxi_utils to return dataset instead of iterator to be consistent with native Keras style

PiperOrigin-RevId: 297774918
  • Loading branch information
tfx-copybara authored and tensorflow-extended-team committed Feb 28, 2020
1 parent 34dd823 commit dfcd914
Show file tree
Hide file tree
Showing 6 changed files with 710 additions and 35 deletions.
1 change: 1 addition & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
* Updated Evaluator's executor to support model validation.
* Introduced awareness of chief worker to Trainer's executor, in case running
in distributed training cluster.
* Added a Chicago Taxi example with native Keras.

## Bug fixes and other changes
* Added --skaffold_cmd flag when updating a pipeline for kubeflow in CLI.
Expand Down
49 changes: 32 additions & 17 deletions tfx/examples/chicago_taxi_pipeline/README.md
Original file line number Diff line number Diff line change
@@ -1,21 +1,23 @@
# Chicago Taxi Example

The Chicago Taxi example demonstrates the end-to-end workflow and steps of how
to analyze, validate and transform data, train a model, analyze and serve it. It
uses the following [TFX](https://www.tensorflow.org/tfx) components:
The Chicago Taxi example demonstrates the end-to-end workflow and the steps
required to analyze, validate, and transform data, train a model, analyze its
performance, and serve it. This example uses the following
[TFX](https://www.tensorflow.org/tfx) components:

* [ExampleGen](https://github.com/tensorflow/tfx/blob/master/docs/guide/examplegen.md)
ingests and splits the input dataset.
* [StatisticsGen](https://github.com/tensorflow/tfx/blob/master/docs/guide/statsgen.md)
calculates statistics for the dataset.
* [SchemaGen](https://github.com/tensorflow/tfx/blob/master/docs/guide/schemagen.md)
SchemaGen examines the statistics and creates a data schema.
examines the statistics and creates a data schema.
* [ExampleValidator](https://github.com/tensorflow/tfx/blob/master/docs/guide/exampleval.md)
looks for anomalies and missing values in the dataset.
* [Transform](https://github.com/tensorflow/tfx/blob/master/docs/guide/transform.md)
performs feature engineering on the dataset.
* [Trainer](https://github.com/tensorflow/tfx/blob/master/docs/guide/trainer.md)
trains the model using TensorFlow [Estimators](https://www.tensorflow.org/guide/estimators)
or [Keras](https://www.tensorflow.org/guide/keras).
* [Evaluator](https://github.com/tensorflow/tfx/blob/master/docs/guide/evaluator.md)
performs deep analysis of the training results.
* [ModelValidator](https://github.com/tensorflow/tfx/blob/master/docs/guide/modelval.md)
Expand Down Expand Up @@ -92,16 +94,9 @@ export TFX_DIR=~/tfx

Next, install the dependencies required by the Chicago Taxi example:

<!--- bring back once requirements.txt file is available
<pre class="devsite-terminal devsite-click-to-copy">
pip install -r requirements.txt
</pre>
-->

<pre class="devsite-terminal devsite-click-to-copy">
pip install tensorflow==1.14.0
pip install apache-airflow==1.10.5
pip install tfx==0.14.0
pip install apache-airflow==1.10.9
pip install tfx==0.21.0
</pre>

Next, initialize Airflow
Expand Down Expand Up @@ -227,6 +222,24 @@ as above local example with local directory changing to `gs://YOUR_BUCKET`.

For more information, see [TensorFlow Serving](https://www.tensorflow.org/serving).

# Chicago Taxi Beam Orchestrator Example

Instead of using Airflow as orchestrator, [beam example](https://github.com/tensorflow/tfx/blob/r0.21/tfx/examples/chicago_taxi_pipeline/taxi_pipeline_beam.py)
use [Beam as orchestrator](https://github.com/tensorflow/tfx/blob/r0.21/docs/guide/beam_orchestrator.md).

To run the example, install tfx in the virtualenv, and copy data
and user module file to $TAXI_DIR as above [instruction](#copy-the-pipeline-definition-to-airflows-dag-directory).
Then simply run `python taxi_pipeline_beam.py` to execute the pipeline.

# Chicago Taxi Kubeflow Orchestrator Example

Use [Kubeflow as orchestrator](https://github.com/tensorflow/tfx/blob/r0.21/docs/guide/kubeflow.md), check [here](https://github.com/kubeflow/pipelines/tree/master/samples/core/tfx-oss) for details.

# Chicago Taxi Native Keras Example (tfx 0.21.1)

Instead of estimator, this example uses native Keras in user module file
`taxi_utils_native_keras.py`.

# Chicago Taxi Flink Example

This section requires the [local prerequisites](#local_prerequisites) and adds a
Expand All @@ -248,8 +261,9 @@ This will start a local Beam Job Server.
The Apache Flink UI can be viewed at http://localhost:8081.

To run tfx e2e on Flink, open a new terminal and activate another instance of
same `virtualenv`. Follow above instructions of Chicago Taxi Example with
'taxi_pipeline_simple' replaced by 'taxi_pipeline_portable_beam'.
same `virtualenv`. Follow the setup for [beam orchestrator example](#chicago-taxi-beam-orchestrator-example),
and then run `python taxi_pipeline_portable_beam.py` to execute the pipeline
with Flink.

# Chicago Taxi Spark Example

Expand All @@ -274,8 +288,9 @@ http://localhost:4040 for the Spark application UI (while a job is running).


To run tfx e2e on Spark, open a new terminal and activate another instance of
same `virtualenv`. Follow above instructions of Chicago Taxi Example with
'taxi_pipeline_simple' replaced by 'taxi_pipeline_portable_beam'.
same `virtualenv`. Follow the setup for [beam orchestrator example](#chicago-taxi-beam-orchestrator-example),
and then run `python taxi_pipeline_portable_beam.py` to execute the pipeline
with Spark.

# Learn more

Expand Down
191 changes: 191 additions & 0 deletions tfx/examples/chicago_taxi_pipeline/taxi_pipeline_native_keras.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Chicago taxi example using TFX."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
from typing import Text

import absl
import tensorflow_model_analysis as tfma

from tfx.components import CsvExampleGen
from tfx.components import Evaluator
from tfx.components import ExampleValidator
from tfx.components import Pusher
from tfx.components import ResolverNode
from tfx.components import SchemaGen
from tfx.components import StatisticsGen
from tfx.components import Trainer
from tfx.components import Transform
from tfx.components.base import executor_spec
from tfx.components.trainer.executor import GenericExecutor
from tfx.dsl.experimental import latest_blessed_model_resolver
from tfx.orchestration import metadata
from tfx.orchestration import pipeline
from tfx.orchestration.beam.beam_dag_runner import BeamDagRunner
from tfx.proto import pusher_pb2
from tfx.proto import trainer_pb2
from tfx.types import Channel
from tfx.types.standard_artifacts import Model
from tfx.types.standard_artifacts import ModelBlessing
from tfx.utils.dsl_utils import external_input

_pipeline_name = 'chicago_taxi_native_keras'

# This example assumes that the taxi data is stored in ~/taxi/data and the
# taxi utility function is in ~/taxi. Feel free to customize this as needed.
_taxi_root = os.path.join(os.environ['HOME'], 'taxi')
_data_root = os.path.join(_taxi_root, 'data', 'simple')
# Python module file to inject customized logic into the TFX components. The
# Transform and Trainer both require user-defined functions to run successfully.
_module_file = os.path.join(_taxi_root, 'taxi_utils_native_keras.py')
# Path which can be listened to by the model server. Pusher will output the
# trained model here.
_serving_model_dir = os.path.join(_taxi_root, 'serving_model', _pipeline_name)

# Directory and data locations. This example assumes all of the chicago taxi
# example code and metadata library is relative to $HOME, but you can store
# these files anywhere on your local filesystem.
_tfx_root = os.path.join(os.environ['HOME'], 'tfx')
_pipeline_root = os.path.join(_tfx_root, 'pipelines', _pipeline_name)
# Sqlite ML-metadata db path.
_metadata_path = os.path.join(_tfx_root, 'metadata', _pipeline_name,
'metadata.db')


# TODO(b/137289334): rename this as simple after DAG visualization is done.
def _create_pipeline(pipeline_name: Text, pipeline_root: Text, data_root: Text,
module_file: Text, serving_model_dir: Text,
metadata_path: Text,
direct_num_workers: int) -> pipeline.Pipeline:
"""Implements the chicago taxi pipeline with TFX."""
examples = external_input(data_root)

# Brings data into the pipeline or otherwise joins/converts training data.
example_gen = CsvExampleGen(input=examples)

# Computes statistics over data for visualization and example validation.
statistics_gen = StatisticsGen(examples=example_gen.outputs['examples'])

# Generates schema based on statistics files.
infer_schema = SchemaGen(
statistics=statistics_gen.outputs['statistics'],
infer_feature_shape=False)

# Performs anomaly detection based on statistics and data schema.
validate_stats = ExampleValidator(
statistics=statistics_gen.outputs['statistics'],
schema=infer_schema.outputs['schema'])

# Performs transformations and feature engineering in training and serving.
transform = Transform(
examples=example_gen.outputs['examples'],
schema=infer_schema.outputs['schema'],
module_file=module_file)

# Uses user-provided Python function that implements a model using TF-Learn.
trainer = Trainer(
module_file=module_file,
custom_executor_spec=executor_spec.ExecutorClassSpec(GenericExecutor),
examples=transform.outputs['transformed_examples'],
transform_graph=transform.outputs['transform_graph'],
schema=infer_schema.outputs['schema'],
train_args=trainer_pb2.TrainArgs(num_steps=10000),
eval_args=trainer_pb2.EvalArgs(num_steps=5000))

# Get the latest blessed model for model validation.
model_resolver = ResolverNode(
instance_name='latest_blessed_model_resolver',
resolver_class=latest_blessed_model_resolver.LatestBlessedModelResolver,
model=Channel(type=Model),
model_blessing=Channel(type=ModelBlessing))

# Uses TFMA to compute a evaluation statistics over features of a model and
# perform quality validation of a candidate model (compared to a baseline).
eval_config = tfma.EvalConfig(
model_specs=[
tfma.ModelSpec(name='candidate', label_key='tips'),
tfma.ModelSpec(name='baseline', label_key='tips', is_baseline=True)
],
slicing_specs=[tfma.SlicingSpec()],
metrics_specs=[
tfma.MetricsSpec(metrics=[
tfma.MetricConfig(
class_name='BinaryAccuracy',
threshold=tfma.config.MetricThreshold(
value_threshold=tfma.GenericValueThreshold(
lower_bound={'value': 0.6}),
change_threshold=tfma.GenericChangeThreshold(
direction=tfma.MetricDirection.HIGHER_IS_BETTER,
absolute={'value': -1e-10})))
])
])
model_analyzer = Evaluator(
examples=example_gen.outputs['examples'],
model=trainer.outputs['model'],
baseline_model=model_resolver.outputs['model'],
# Change threshold will be ignored if there is no baseline (first run).
eval_config=eval_config)

# Checks whether the model passed the validation steps and pushes the model
# to a file destination if check passed.
pusher = Pusher(
model=trainer.outputs['model'],
model_blessing=model_analyzer.outputs['blessing'],
push_destination=pusher_pb2.PushDestination(
filesystem=pusher_pb2.PushDestination.Filesystem(
base_directory=serving_model_dir)))

return pipeline.Pipeline(
pipeline_name=pipeline_name,
pipeline_root=pipeline_root,
components=[
example_gen,
statistics_gen,
infer_schema,
validate_stats,
transform,
trainer,
model_resolver,
model_analyzer,
pusher,
],
enable_cache=True,
metadata_connection_config=metadata.sqlite_metadata_connection_config(
metadata_path),
# TODO(b/142684737): The multi-processing API might change.
beam_pipeline_args=['--direct_num_workers=%d' % direct_num_workers])


# To run this pipeline from the python CLI:
# $python taxi_pipeline_native_keras.py
if __name__ == '__main__':
absl.logging.set_verbosity(absl.logging.INFO)

BeamDagRunner().run(
_create_pipeline(
pipeline_name=_pipeline_name,
pipeline_root=_pipeline_root,
data_root=_data_root,
module_file=_module_file,
metadata_path=_metadata_path,
serving_model_dir=_serving_model_dir,
# 0 means auto-detect based on on the number of CPUs available during
# execution time.
direct_num_workers=0))
Loading

0 comments on commit dfcd914

Please sign in to comment.