Skip to content

Commit

Permalink
Added a resolver that can resolve spans based on range_config
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 358520190
  • Loading branch information
tfx-copybara committed Feb 20, 2021
1 parent c2df96e commit 85c365c
Show file tree
Hide file tree
Showing 4 changed files with 299 additions and 1 deletion.
2 changes: 2 additions & 0 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
at `tfx/examples/ranking`. More documentation will be available in future
releases.

* Added a resolver that can resolve spans based on range_config.

## Breaking changes

### For pipeline authors
Expand Down
141 changes: 141 additions & 0 deletions tfx/dsl/experimental/spans_resolver.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,141 @@
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Experimental Resolver for getting the artifacts based on Span."""

from typing import Dict, List, Optional, Text

from tfx import types
from tfx.components.example_gen import utils
from tfx.dsl.components.common import resolver
from tfx.orchestration import data_types
from tfx.orchestration import metadata
from tfx.proto import range_config_pb2
from tfx.types import artifact_utils


class SpansResolver(resolver.ResolverStrategy):
"""Resolver that return the artifacts based on Span.
Note that this Resolver is experimental and is subject to change in terms of
both interface and implementation.
"""

def __init__(self, range_config: range_config_pb2.RangeConfig):
self._range_config = range_config

def _resolve(self, input_dict: Dict[Text, List[types.Artifact]]):
result = {}

for k, artifact_list in input_dict.items():
in_range_artifacts = []

if self._range_config.HasField('static_range'):
start_span_number = self._range_config.static_range.start_span_number
end_span_number = self._range_config.static_range.end_span_number
# Get the artifacts within range.
for artifact in artifact_list:
if not artifact.has_custom_property(utils.SPAN_PROPERTY_NAME):
raise RuntimeError('Span does not exist for' % str(artifact))
span = int(
artifact.get_string_custom_property(utils.SPAN_PROPERTY_NAME))
if span >= start_span_number and span <= end_span_number:
in_range_artifacts.append(artifact)

elif self._range_config.HasField('rolling_range'):
start_span_number = self._range_config.rolling_range.start_span_number
num_spans = self._range_config.rolling_range.num_spans
if num_spans <= 0:
raise ValueError('num_spans should be positive number.')
most_recent_span = -1
# Get most recent span number.
for artifact in artifact_list:
if not artifact.has_custom_property(utils.SPAN_PROPERTY_NAME):
raise RuntimeError('Span does not exist for' % str(artifact))
span = int(
artifact.get_string_custom_property(utils.SPAN_PROPERTY_NAME))
if span > most_recent_span:
most_recent_span = span

start_span_number = max(start_span_number,
most_recent_span - num_spans + 1)
end_span_number = most_recent_span
# Get the artifacts within range.
for artifact in artifact_list:
span = int(
artifact.get_string_custom_property(utils.SPAN_PROPERTY_NAME))
if span >= start_span_number and span <= end_span_number:
in_range_artifacts.append(artifact)

else:
raise ValueError('RangeConfig type is not supported.')

result[k] = sorted(
in_range_artifacts,
key=lambda a: a.get_string_custom_property(utils.SPAN_PROPERTY_NAME),
reverse=True)

return result

def resolve(
self,
pipeline_info: data_types.PipelineInfo,
metadata_handler: metadata.Metadata,
source_channels: Dict[Text, types.Channel],
) -> resolver.ResolveResult:
pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
if pipeline_context is None:
raise RuntimeError('Pipeline context absent for %s' % pipeline_context)

candidate_dict = {}
for k, c in source_channels.items():
candidate_artifacts = metadata_handler.get_qualified_artifacts(
contexts=[pipeline_context],
type_name=c.type_name,
producer_component_id=c.producer_component_id,
output_key=c.output_key)
candidate_dict[k] = [
artifact_utils.deserialize_artifact(a.type, a.artifact)
for a in candidate_artifacts
]

resolved_dict = self._resolve(candidate_dict)
resolve_state_dict = {
k: bool(artifact_list) for k, artifact_list in resolved_dict.items()
}

return resolver.ResolveResult(
per_key_resolve_result=resolved_dict,
per_key_resolve_state=resolve_state_dict)

def resolve_artifacts(
self, metadata_handler: metadata.Metadata,
input_dict: Dict[Text, List[types.Artifact]]
) -> Optional[Dict[Text, List[types.Artifact]]]:
"""Resolves artifacts from channels by querying MLMD.
Args:
metadata_handler: A metadata handler to access MLMD store.
input_dict: The input_dict to resolve from.
Returns:
If `min_count` for every input is met, returns a
Dict[Text, List[Artifact]]. Otherwise, return None.
Raises:
RuntimeError: if input_dict contains artifact without span property.
"""
resolved_dict = self._resolve(input_dict)
all_min_count_met = all(
bool(artifact_list) for artifact_list in resolved_dict.values())
return resolved_dict if all_min_count_met else None
133 changes: 133 additions & 0 deletions tfx/dsl/experimental/spans_resolver_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,133 @@
# Copyright 2021 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test for SpansResolver."""

from typing import Text
# Standard Imports

import tensorflow as tf
from tfx import types
from tfx.components.example_gen import utils
from tfx.dsl.experimental import spans_resolver
from tfx.orchestration import data_types
from tfx.orchestration import metadata
from tfx.proto import range_config_pb2
from tfx.types import standard_artifacts

from ml_metadata.proto import metadata_store_pb2


class SpansResolverTest(tf.test.TestCase):

def setUp(self):
super(SpansResolverTest, self).setUp()
self._connection_config = metadata_store_pb2.ConnectionConfig()
self._connection_config.sqlite.SetInParent()
self._pipeline_info = data_types.PipelineInfo(
pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id')
self._component_info = data_types.ComponentInfo(
component_type='a.b.c',
component_id='my_component',
pipeline_info=self._pipeline_info)

def _createExamples(self, span: Text) -> standard_artifacts.Examples:
artifact = standard_artifacts.Examples()
artifact.uri = 'uri' + span
artifact.set_string_custom_property(utils.SPAN_PROPERTY_NAME, span)
return artifact

def testResolve(self):
with metadata.Metadata(connection_config=self._connection_config) as m:
contexts = m.register_pipeline_contexts_if_not_exists(self._pipeline_info)
artifact_one = standard_artifacts.Examples()
artifact_one.uri = 'uri_one'
artifact_one.set_string_custom_property(utils.SPAN_PROPERTY_NAME, '1')
m.publish_artifacts([artifact_one])
artifact_two = standard_artifacts.Examples()
artifact_two.uri = 'uri_two'
artifact_two.set_string_custom_property(utils.SPAN_PROPERTY_NAME, '2')
m.register_execution(
exec_properties={},
pipeline_info=self._pipeline_info,
component_info=self._component_info,
contexts=contexts)
m.publish_execution(
component_info=self._component_info,
output_artifacts={'key': [artifact_one, artifact_two]})

resolver = spans_resolver.SpansResolver(
range_config=range_config_pb2.RangeConfig(
static_range=range_config_pb2.StaticRange(
start_span_number=1, end_span_number=1)))
resolve_result = resolver.resolve(
pipeline_info=self._pipeline_info,
metadata_handler=m,
source_channels={
'input':
types.Channel(
type=artifact_one.type,
producer_component_id=self._component_info.component_id,
output_key='key')
})

self.assertTrue(resolve_result.has_complete_result)
self.assertEqual([
artifact.uri
for artifact in resolve_result.per_key_resolve_result['input']
], [artifact_one.uri])
self.assertTrue(resolve_result.per_key_resolve_state['input'])

def testResolveArtifacts(self):
with metadata.Metadata(connection_config=self._connection_config) as m:
artifact1 = self._createExamples('1')
artifact2 = self._createExamples('2')
artifact3 = self._createExamples('3')
artifact4 = self._createExamples('4')
artifact5 = self._createExamples('5')

# Test StaticRange.
resolver = spans_resolver.SpansResolver(
range_config=range_config_pb2.RangeConfig(
static_range=range_config_pb2.StaticRange(
start_span_number=2, end_span_number=3)))
result = resolver.resolve_artifacts(
m, {'input': [artifact1, artifact2, artifact3, artifact4, artifact5]})
self.assertIsNotNone(result)
self.assertEqual([a.uri for a in result['input']],
[artifact3.uri, artifact2.uri])

# Test RollingRange.
resolver = spans_resolver.SpansResolver(
range_config=range_config_pb2.RangeConfig(
rolling_range=range_config_pb2.RollingRange(num_spans=3)))
result = resolver.resolve_artifacts(
m, {'input': [artifact1, artifact2, artifact3, artifact4, artifact5]})
self.assertIsNotNone(result)
self.assertEqual([a.uri for a in result['input']],
[artifact5.uri, artifact4.uri, artifact3.uri])

# Test RollingRange with start_span_number.
resolver = spans_resolver.SpansResolver(
range_config=range_config_pb2.RangeConfig(
rolling_range=range_config_pb2.RollingRange(
start_span_number=4, num_spans=3)))
result = resolver.resolve_artifacts(
m, {'input': [artifact1, artifact2, artifact3, artifact4, artifact5]})
self.assertIsNotNone(result)
self.assertEqual([a.uri for a in result['input']],
[artifact5.uri, artifact4.uri])


if __name__ == '__main__':
tf.test.main()
24 changes: 23 additions & 1 deletion tfx/proto/range_config.proto
Original file line number Diff line number Diff line change
Expand Up @@ -23,13 +23,35 @@ message StaticRange {
int32 end_span_number = 2;
}

// Describes a rolling range:
// [most_recent_span - num_spans + 1,
// most_recent_span].
// For example, say you want the range to include only the latest span,
// the appropriate RollingRange would simply be:
// RollingRange <
// num_spans = 1
// >
// The range is clipped based on available data.
// Note that num_spans is required in RollingRange, while others are optional.
message RollingRange {
// Starting span before which no span will be considered.
// This is useful to clip the range in case the user
// wants to start front-filling some feature column after a certain date.
int32 start_span_number = 1;
// Length of the range.
int32 num_spans = 2;

reserved 3;
}

// RangeConfig is an abstract proto which can be used to describe ranges
// for different entities in TFX Pipeline. All indices corespond to increasing
// span numbers starting from the initial span at index 0.
message RangeConfig {
oneof range {
StaticRange static_range = 1;
RollingRange rolling_range = 2;
}

reserved 2, 3, 4, 5;
reserved 3, 4, 5;
}

0 comments on commit 85c365c

Please sign in to comment.