Skip to content

Commit

Permalink
Allow providing StatsOptions and Schema to StatisticsGen component.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 299460239
  • Loading branch information
embr authored and tensorflow-extended-team committed Mar 7, 2020
1 parent f9e70b3 commit 6f4eee8
Show file tree
Hide file tree
Showing 9 changed files with 204 additions and 11 deletions.
3 changes: 3 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
# Current version (not yet released; still in development)

## Major Features and Improvements
* Updated `StatisticsGen` to optionally consume a schema `Artifact`.
* Added support for configuring the `StatisticsGen` component via serializable
parts of `StatsOptions`.

## Bug fixes and other changes
* Fix the behavior of Trainer Tensorboard visualization when caching is used.
Expand Down
29 changes: 29 additions & 0 deletions docs/guide/statsgen.md
Original file line number Diff line number Diff line change
Expand Up @@ -29,3 +29,32 @@ compute_eval_stats = components.StatisticsGen(
name='compute-eval-stats'
)
```

## Using the StatsGen Component With a Schema

For the first run of a pipeline, the output of StatisticsGen will be used to
infer a schema. However, on subsequent runs you may have a manually curated
schema that contains additional information about your data set. By providing
this schema to StatisticsGen, TFDV can provide more useful statistics based on
declared properties of your data set.

In this setting, you will invoke StatisticsGen with a curated schema that has
been imported by an ImporterNode like this:

```python
from tfx import components
from tfx.types import standard_artifacts

...

user_schema_importer = components.ImporterNode(
instance_name='import_user_schema',
source_uri=user_schema_path,
    artifact_type=standard_artifacts.Schema)

compute_eval_stats = components.StatisticsGen(
examples=example_gen.outputs['examples'],
schema=user_schema_importer.outputs['result'],
name='compute-eval-stats'
)
```
18 changes: 17 additions & 1 deletion tfx/components/statistics_gen/component.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@
from typing import Optional, Text

import absl
import tensorflow_data_validation as tfdv

from tfx import types
from tfx.components.base import base_component
Expand Down Expand Up @@ -53,6 +54,8 @@ class StatisticsGen(base_component.BaseComponent):

def __init__(self,
examples: types.Channel = None,
schema: Optional[types.Channel] = None,
stats_options: Optional[tfdv.StatsOptions] = None,
output: Optional[types.Channel] = None,
input_data: Optional[types.Channel] = None,
instance_name: Optional[Text] = None):
Expand All @@ -62,6 +65,13 @@ def __init__(self,
examples: A Channel of `ExamplesPath` type, likely generated by the
[ExampleGen component](https://www.tensorflow.org/tfx/guide/examplegen).
This needs to contain two splits labeled `train` and `eval`. _required_
schema: A `Schema` channel to use for automatically configuring the value
of stats options passed to TFDV.
stats_options: The StatsOptions instance to configure optional TFDV
behavior. When stats_options.schema is set, it will be used instead of
the `schema` channel input. Due to the requirement that stats_options be
serialized, the slicer functions and custom stats generators are dropped
and are therefore not usable.
output: `ExampleStatisticsPath` channel for statistics of each split
provided in the input examples.
input_data: Backwards compatibility alias for the `examples` argument.
Expand All @@ -82,5 +92,11 @@ def __init__(self,
output = types.Channel(
type=standard_artifacts.ExampleStatistics,
artifacts=[statistics_artifact])
spec = StatisticsGenSpec(examples=examples, statistics=output)
# TODO(b/150802589): Move jsonable interface to tfx_bsl and use json_utils.
stats_options_json = stats_options.to_json() if stats_options else None
spec = StatisticsGenSpec(
examples=examples,
schema=schema,
stats_options_json=stats_options_json,
statistics=output)
super(StatisticsGen, self).__init__(spec=spec, instance_name=instance_name)
20 changes: 20 additions & 0 deletions tfx/components/statistics_gen/component_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from __future__ import print_function

import tensorflow as tf
import tensorflow_data_validation as tfdv
from tfx.components.statistics_gen import component
from tfx.types import artifact_utils
from tfx.types import channel_utils
Expand All @@ -35,6 +36,25 @@ def testConstruct(self):
self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
statistics_gen.outputs['statistics'].type_name)

def testConstructWithSchemaAndStatsOptions(self):
examples = standard_artifacts.Examples()
examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])
schema = standard_artifacts.Schema()
stats_options = tfdv.StatsOptions(
weight_feature='weight',
generators=[ # generators should be dropped
tfdv.LiftStatsGenerator(
schema=None,
y_path=tfdv.FeaturePath(['label']),
x_paths=[tfdv.FeaturePath(['feature'])])
])
statistics_gen = component.StatisticsGen(
examples=channel_utils.as_channel([examples]),
schema=channel_utils.as_channel([schema]),
stats_options=stats_options)
self.assertEqual(standard_artifacts.ExampleStatistics.TYPE_NAME,
statistics_gen.outputs['statistics'].type_name)


if __name__ == '__main__':
tf.test.main()
40 changes: 35 additions & 5 deletions tfx/components/statistics_gen/executor.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,10 +34,14 @@
from tfx.utils import io_utils


# Key for examples in executor input_dict.
# Keys for input_dict.
EXAMPLES_KEY = 'examples'
SCHEMA_KEY = 'schema'

# Key for output statistics in executor output_dict.
# Keys for exec_properties dict.
STATS_OPTIONS_JSON_KEY = 'stats_options_json'

# Keys for output_dict
STATISTICS_KEY = 'statistics'

# Default file name for stats generated.
Expand All @@ -64,24 +68,50 @@ def Do(self, input_dict: Dict[Text, List[types.Artifact]],
input_dict: Input dict from input key to a list of Artifacts.
- input_data: A list of type `standard_artifacts.Examples`. This should
contain both 'train' and 'eval' split.
- schema: Optionally, a list of type `standard_artifacts.Schema`. When
the stats_options exec_property also contains a schema, this input
should not be provided.
output_dict: Output dict from output key to a list of Artifacts.
- output: A list of type `standard_artifacts.ExampleStatistics`. This
should contain both the 'train' and 'eval' splits.
exec_properties: A dict of execution properties. Not used yet.
exec_properties: A dict of execution properties.
- stats_options_json: Optionally, a JSON representation of StatsOptions.
When a schema is provided as an input, the StatsOptions value should
not also contain a schema.
Raises:
ValueError when a schema is provided both as an input and as part of the
StatsOptions exec_property.
Returns:
None
"""
self._log_startup(input_dict, output_dict, exec_properties)

stats_options = options.StatsOptions()
if STATS_OPTIONS_JSON_KEY in exec_properties:
stats_options_json = exec_properties[STATS_OPTIONS_JSON_KEY]
if stats_options_json:
# TODO(b/150802589): Move jsonable interface to tfx_bsl and use
# json_utils
stats_options = options.StatsOptions.from_json(stats_options_json)
if input_dict.get(SCHEMA_KEY):
if stats_options.schema:
raise ValueError('A schema was provided as an input and the '
'stats_options exec_property also contains a schema '
'value. At most one of these may be set.')
else:
schema = io_utils.SchemaReader().read(
io_utils.get_only_uri_in_dir(
artifact_utils.get_single_uri(input_dict[SCHEMA_KEY])))
stats_options.schema = schema

split_uris = []
for artifact in input_dict[EXAMPLES_KEY]:
for split in artifact_utils.decode_split_names(artifact.split_names):
uri = os.path.join(artifact.uri, split)
split_uris.append((split, uri))
with self._make_beam_pipeline() as p:
# TODO(b/126263006): Support more stats_options through config.
stats_options = options.StatsOptions()
for split, uri in split_uris:
absl.logging.info('Generating statistics for split {}'.format(split))
input_uri = io_utils.all_files_pattern(uri)
Expand Down
91 changes: 89 additions & 2 deletions tfx/components/statistics_gen/executor_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
from absl.testing import absltest
import tensorflow as tf
import tensorflow_data_validation as tfdv
from tensorflow_metadata.proto.v0 import schema_pb2
from tfx.components.statistics_gen import executor
from tfx.types import artifact_utils
from tfx.types import standard_artifacts
Expand Down Expand Up @@ -71,15 +72,101 @@ def testDo(self):
}

# Run executor.
evaluator = executor.Executor()
evaluator.Do(input_dict, output_dict, exec_properties={})
stats_gen_executor = executor.Executor()
stats_gen_executor.Do(input_dict, output_dict, exec_properties={})

# Check statistics_gen outputs.
self._validate_stats_output(
os.path.join(stats.uri, 'train', 'stats_tfrecord'))
self._validate_stats_output(
os.path.join(stats.uri, 'eval', 'stats_tfrecord'))

def testDoWithSchemaAndStatsOptions(self):
  """Runs the executor with a schema input plus a schema-free StatsOptions.

  The StatsOptions passed via exec_properties deliberately has no schema, so
  the schema input artifact supplies it; the run should succeed and write
  stats for both splits.
  """
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  tf.io.gfile.makedirs(output_data_dir)

  # Create input dict: examples from the csv_example_gen fixture plus a
  # curated schema from the schema_gen fixture.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])

  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(source_data_dir, 'schema_gen')

  input_dict = {
      executor.EXAMPLES_KEY: [examples],
      executor.SCHEMA_KEY: [schema]
  }

  # StatsOptions is serialized to JSON; note it carries no schema here.
  exec_properties = {
      executor.STATS_OPTIONS_JSON_KEY:
          tfdv.StatsOptions(label_feature='company').to_json(),
  }

  # Create output dict.
  stats = standard_artifacts.ExampleStatistics()
  stats.uri = output_data_dir
  stats.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  output_dict = {
      executor.STATISTICS_KEY: [stats],
  }

  # Run executor.
  stats_gen_executor = executor.Executor()
  stats_gen_executor.Do(
      input_dict, output_dict, exec_properties=exec_properties)

  # Check statistics_gen outputs: one stats file per split.
  self._validate_stats_output(
      os.path.join(stats.uri, 'train', 'stats_tfrecord'))
  self._validate_stats_output(
      os.path.join(stats.uri, 'eval', 'stats_tfrecord'))

def testDoWithTwoSchemas(self):
  """Verifies the executor rejects a schema given both ways.

  Supplies a schema input artifact AND a StatsOptions whose JSON already
  contains a schema; Do() must raise ValueError since at most one source of
  schema is allowed.
  """
  source_data_dir = os.path.join(
      os.path.dirname(os.path.dirname(__file__)), 'testdata')
  output_data_dir = os.path.join(
      os.environ.get('TEST_UNDECLARED_OUTPUTS_DIR', self.get_temp_dir()),
      self._testMethodName)
  tf.io.gfile.makedirs(output_data_dir)

  # Create input dict: examples plus a schema input artifact.
  examples = standard_artifacts.Examples()
  examples.uri = os.path.join(source_data_dir, 'csv_example_gen')
  examples.split_names = artifact_utils.encode_split_names(['train', 'eval'])

  schema = standard_artifacts.Schema()
  schema.uri = os.path.join(source_data_dir, 'schema_gen')

  input_dict = {
      executor.EXAMPLES_KEY: [examples],
      executor.SCHEMA_KEY: [schema]
  }

  # Conflicting configuration: StatsOptions also embeds a schema.
  exec_properties = {
      executor.STATS_OPTIONS_JSON_KEY:
          tfdv.StatsOptions(label_feature='company',
                            schema=schema_pb2.Schema()).to_json(),
  }

  # Create output dict.
  stats = standard_artifacts.ExampleStatistics()
  stats.uri = output_data_dir
  stats.split_names = artifact_utils.encode_split_names(['train', 'eval'])
  output_dict = {
      executor.STATISTICS_KEY: [stats],
  }

  # Run executor: the duplicate schema must be rejected.
  stats_gen_executor = executor.Executor()
  with self.assertRaises(ValueError):
    stats_gen_executor.Do(
        input_dict, output_dict, exec_properties=exec_properties)


if __name__ == '__main__':
absltest.main()
4 changes: 3 additions & 1 deletion tfx/orchestration/kubeflow/testdata/component.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
"__class__": "NodeWrapper",
"__module__": "tfx.orchestration.kubeflow.node_wrapper",
"__tfx_object_type__": "jsonable",
"_exec_properties": {},
"_exec_properties": {
"stats_options_json": null
},
"_id": "StatisticsGen.foo",
"_inputs": {
"__class__": "_PropertyDictWrapper",
Expand Down
4 changes: 3 additions & 1 deletion tfx/orchestration/kubeflow/testdata/statistics_gen.json
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
"__class__": "NodeWrapper",
"__module__": "tfx.orchestration.kubeflow.node_wrapper",
"__tfx_object_type__": "jsonable",
"_exec_properties": {},
"_exec_properties": {
"stats_options_json": null
},
"_id": "StatisticsGen.foo",
"_inputs": {
"__class__": "_PropertyDictWrapper",
Expand Down
6 changes: 5 additions & 1 deletion tfx/types/standard_component_specs.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,9 +253,13 @@ class SchemaGenSpec(ComponentSpec):
class StatisticsGenSpec(ComponentSpec):
"""StatisticsGen component spec."""

PARAMETERS = {}
PARAMETERS = {
'stats_options_json':
ExecutionParameter(type=(str, Text), optional=True),
}
INPUTS = {
'examples': ChannelParameter(type=standard_artifacts.Examples),
'schema': ChannelParameter(type=standard_artifacts.Schema, optional=True),
}
OUTPUTS = {
'statistics': ChannelParameter(type=standard_artifacts.ExampleStatistics),
Expand Down

0 comments on commit 6f4eee8

Please sign in to comment.