forked from tensorflow/tfx
-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Added a resolver that can resolve spans based on range_config
PiperOrigin-RevId: 358520190
- Loading branch information
1 parent
c2df96e
commit 85c365c
Showing
4 changed files
with
299 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,141 @@ | ||
# Copyright 2021 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Experimental Resolver for getting the artifacts based on Span.""" | ||
|
||
from typing import Dict, List, Optional, Text | ||
|
||
from tfx import types | ||
from tfx.components.example_gen import utils | ||
from tfx.dsl.components.common import resolver | ||
from tfx.orchestration import data_types | ||
from tfx.orchestration import metadata | ||
from tfx.proto import range_config_pb2 | ||
from tfx.types import artifact_utils | ||
|
||
|
||
class SpansResolver(resolver.ResolverStrategy):
  """Resolver that returns the artifacts whose Span is selected by a RangeConfig.

  Two RangeConfig variants are supported:
    * static_range: keeps spans in [start_span_number, end_span_number].
    * rolling_range: keeps the `num_spans` most recent spans, ending at the
      largest span present, and never going below `start_span_number`.

  Resolved artifacts are returned newest span first (numeric order).

  Note that this Resolver is experimental and is subject to change in terms of
  both interface and implementation.
  """

  def __init__(self, range_config: range_config_pb2.RangeConfig):
    """Initializes the resolver.

    Args:
      range_config: The RangeConfig describing which spans to keep.
    """
    self._range_config = range_config

  @staticmethod
  def _get_span(artifact: types.Artifact) -> int:
    """Returns the span custom property of `artifact` as an int.

    Args:
      artifact: The artifact to read the span from.

    Returns:
      The span number parsed from the string custom property.

    Raises:
      RuntimeError: if the artifact has no span custom property.
    """
    if not artifact.has_custom_property(utils.SPAN_PROPERTY_NAME):
      # The message needs a %s placeholder; without it the % operator itself
      # raises TypeError and the intended RuntimeError is never seen.
      raise RuntimeError('Span does not exist for %s' % str(artifact))
    return int(artifact.get_string_custom_property(utils.SPAN_PROPERTY_NAME))

  def _resolve(self, input_dict: Dict[Text, List[types.Artifact]]):
    """Filters every artifact list down to the spans selected by the config.

    Args:
      input_dict: Maps each input key to its candidate artifacts.

    Returns:
      A dict with the same keys as `input_dict`. Each value holds the in-range
      artifacts sorted by span number, largest span first.

    Raises:
      RuntimeError: if any artifact has no span custom property.
      ValueError: if `rolling_range.num_spans` is not positive, or if the
        RangeConfig sets neither `static_range` nor `rolling_range`.
    """
    result = {}

    for k, artifact_list in input_dict.items():
      # The config is (re)validated per key before any artifact is inspected,
      # so an empty `input_dict` never raises.
      if self._range_config.HasField('static_range'):
        static_range = self._range_config.static_range
        spans = [self._get_span(artifact) for artifact in artifact_list]
        start_span_number = static_range.start_span_number
        end_span_number = static_range.end_span_number

      elif self._range_config.HasField('rolling_range'):
        rolling_range = self._range_config.rolling_range
        num_spans = rolling_range.num_spans
        if num_spans <= 0:
          raise ValueError('num_spans should be positive number.')
        spans = [self._get_span(artifact) for artifact in artifact_list]
        # -1 is the floor used when there are no artifacts; the resulting
        # window then matches nothing.
        most_recent_span = max([-1] + spans)
        # The window covers the last `num_spans` spans but is clipped so it
        # never starts before the configured start_span_number.
        start_span_number = max(rolling_range.start_span_number,
                                most_recent_span - num_spans + 1)
        end_span_number = most_recent_span

      else:
        raise ValueError('RangeConfig type is not supported.')

      in_range = [(span, artifact)
                  for span, artifact in zip(spans, artifact_list)
                  if start_span_number <= span <= end_span_number]
      # Sort on the integer span: sorting on the string property would place
      # span '10' before span '9'.
      in_range.sort(key=lambda pair: pair[0], reverse=True)
      result[k] = [artifact for _, artifact in in_range]

    return result

  def resolve(
      self,
      pipeline_info: data_types.PipelineInfo,
      metadata_handler: metadata.Metadata,
      source_channels: Dict[Text, types.Channel],
  ) -> resolver.ResolveResult:
    """Resolves in-range artifacts for each channel by querying MLMD.

    Args:
      pipeline_info: Identifies the pipeline whose context is queried.
      metadata_handler: A metadata handler to access MLMD store.
      source_channels: Maps each input key to the channel to resolve.

    Returns:
      A ResolveResult holding the resolved artifacts per key, and per key a
      flag that is True when at least one artifact was resolved.

    Raises:
      RuntimeError: if the pipeline context is absent, or if a candidate
        artifact has no span custom property.
      ValueError: if the RangeConfig is invalid or unsupported.
    """
    pipeline_context = metadata_handler.get_pipeline_context(pipeline_info)
    if pipeline_context is None:
      # Report the pipeline that was looked up; the context itself is None.
      raise RuntimeError('Pipeline context absent for %s' % pipeline_info)

    candidate_dict = {}
    for k, c in source_channels.items():
      candidate_artifacts = metadata_handler.get_qualified_artifacts(
          contexts=[pipeline_context],
          type_name=c.type_name,
          producer_component_id=c.producer_component_id,
          output_key=c.output_key)
      candidate_dict[k] = [
          artifact_utils.deserialize_artifact(a.type, a.artifact)
          for a in candidate_artifacts
      ]

    resolved_dict = self._resolve(candidate_dict)
    resolve_state_dict = {
        k: bool(artifact_list) for k, artifact_list in resolved_dict.items()
    }

    return resolver.ResolveResult(
        per_key_resolve_result=resolved_dict,
        per_key_resolve_state=resolve_state_dict)

  def resolve_artifacts(
      self, metadata_handler: metadata.Metadata,
      input_dict: Dict[Text, List[types.Artifact]]
  ) -> Optional[Dict[Text, List[types.Artifact]]]:
    """Resolves in-range artifacts from an already materialized input dict.

    Args:
      metadata_handler: A metadata handler to access MLMD store. Not used by
        this implementation.
      input_dict: The input_dict to resolve from.

    Returns:
      If every input key resolves to at least one artifact, returns a
      Dict[Text, List[Artifact]]. Otherwise, return None.

    Raises:
      RuntimeError: if input_dict contains artifact without span property.
      ValueError: if the RangeConfig is invalid or unsupported.
    """
    resolved_dict = self._resolve(input_dict)
    all_resolved = all(
        bool(artifact_list) for artifact_list in resolved_dict.values())
    return resolved_dict if all_resolved else None
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
# Copyright 2021 Google LLC. All Rights Reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. | ||
"""Test for SpansResolver.""" | ||
|
||
from typing import Text | ||
# Standard Imports | ||
|
||
import tensorflow as tf | ||
from tfx import types | ||
from tfx.components.example_gen import utils | ||
from tfx.dsl.experimental import spans_resolver | ||
from tfx.orchestration import data_types | ||
from tfx.orchestration import metadata | ||
from tfx.proto import range_config_pb2 | ||
from tfx.types import standard_artifacts | ||
|
||
from ml_metadata.proto import metadata_store_pb2 | ||
|
||
|
||
class SpansResolverTest(tf.test.TestCase):
  """Tests for SpansResolver using a SQLite MLMD connection config."""

  def setUp(self):
    """Builds the connection config and the pipeline/component descriptors."""
    super().setUp()

    connection_config = metadata_store_pb2.ConnectionConfig()
    connection_config.sqlite.SetInParent()
    self._connection_config = connection_config

    pipeline_info = data_types.PipelineInfo(
        pipeline_name='my_pipeline', pipeline_root='/tmp', run_id='my_run_id')
    self._pipeline_info = pipeline_info
    self._component_info = data_types.ComponentInfo(
        component_type='a.b.c',
        component_id='my_component',
        pipeline_info=pipeline_info)

  def _createExamples(self, span: Text) -> standard_artifacts.Examples:
    """Returns an Examples artifact whose uri and span derive from `span`."""
    examples = standard_artifacts.Examples()
    examples.uri = 'uri' + span
    examples.set_string_custom_property(utils.SPAN_PROPERTY_NAME, span)
    return examples

  def testResolve(self):
    """resolve() returns only the published artifact inside the static range."""
    with metadata.Metadata(connection_config=self._connection_config) as m:
      contexts = m.register_pipeline_contexts_if_not_exists(self._pipeline_info)

      first_examples = standard_artifacts.Examples()
      first_examples.uri = 'uri_one'
      first_examples.set_string_custom_property(utils.SPAN_PROPERTY_NAME, '1')
      m.publish_artifacts([first_examples])

      second_examples = standard_artifacts.Examples()
      second_examples.uri = 'uri_two'
      second_examples.set_string_custom_property(utils.SPAN_PROPERTY_NAME, '2')

      m.register_execution(
          exec_properties={},
          pipeline_info=self._pipeline_info,
          component_info=self._component_info,
          contexts=contexts)
      m.publish_execution(
          component_info=self._component_info,
          output_artifacts={'key': [first_examples, second_examples]})

      # Only span 1 is inside the requested static range.
      only_span_one = range_config_pb2.RangeConfig(
          static_range=range_config_pb2.StaticRange(
              start_span_number=1, end_span_number=1))
      resolver = spans_resolver.SpansResolver(range_config=only_span_one)
      resolve_result = resolver.resolve(
          pipeline_info=self._pipeline_info,
          metadata_handler=m,
          source_channels={
              'input':
                  types.Channel(
                      type=first_examples.type,
                      producer_component_id=self._component_info.component_id,
                      output_key='key')
          })

      self.assertTrue(resolve_result.has_complete_result)
      resolved_uris = [
          artifact.uri
          for artifact in resolve_result.per_key_resolve_result['input']
      ]
      self.assertEqual(resolved_uris, [first_examples.uri])
      self.assertTrue(resolve_result.per_key_resolve_state['input'])

  def testResolveArtifacts(self):
    """resolve_artifacts() honours StaticRange and RollingRange configs."""
    with metadata.Metadata(connection_config=self._connection_config) as m:
      # One Examples artifact per span 1..5, keyed by span number.
      examples_by_span = {
          span: self._createExamples(str(span)) for span in range(1, 6)
      }

      def resolve_uris(range_config):
        """Resolves all five artifacts and returns the resulting uris."""
        resolver = spans_resolver.SpansResolver(range_config=range_config)
        result = resolver.resolve_artifacts(
            m, {'input': list(examples_by_span.values())})
        self.assertIsNotNone(result)
        return [a.uri for a in result['input']]

      # Test StaticRange.
      static_uris = resolve_uris(
          range_config_pb2.RangeConfig(
              static_range=range_config_pb2.StaticRange(
                  start_span_number=2, end_span_number=3)))
      self.assertEqual(static_uris,
                       [examples_by_span[3].uri, examples_by_span[2].uri])

      # Test RollingRange.
      rolling_uris = resolve_uris(
          range_config_pb2.RangeConfig(
              rolling_range=range_config_pb2.RollingRange(num_spans=3)))
      self.assertEqual(rolling_uris, [
          examples_by_span[5].uri, examples_by_span[4].uri,
          examples_by_span[3].uri
      ])

      # Test RollingRange with start_span_number.
      clipped_uris = resolve_uris(
          range_config_pb2.RangeConfig(
              rolling_range=range_config_pb2.RollingRange(
                  start_span_number=4, num_spans=3)))
      self.assertEqual(clipped_uris,
                       [examples_by_span[5].uri, examples_by_span[4].uri])
|
||
|
||
if __name__ == '__main__':
  # Hand control to the TensorFlow test runner when executed as a script.
  tf.test.main()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters