Skip to content

Commit

Permalink
Removes deprecated inputs_utils.resolve_input_artifacts.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 446350768
  • Loading branch information
chongkong authored and tfx-copybara committed May 4, 2022
1 parent 2df92b3 commit 800367e
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 125 deletions.
55 changes: 1 addition & 54 deletions tfx/orchestration/portable/inputs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Portable library for input artifacts resolution."""
from typing import Dict, Iterable, List, Optional, Mapping, Sequence, Union, cast
from typing import Dict, Iterable, List, Optional, Mapping, Sequence, Union

from absl import logging
from tfx import types
Expand All @@ -24,7 +24,6 @@
from tfx.orchestration.portable.mlmd import execution_lib
from tfx.proto.orchestration import pipeline_pb2
from tfx.types import artifact_utils
from tfx.utils import deprecation_utils
from tfx.utils import typing_utils

import ml_metadata as mlmd
Expand Down Expand Up @@ -141,58 +140,6 @@ def _is_sufficient(artifact_multimap: Mapping[str, Sequence[types.Artifact]],
if key in node_inputs.inputs)


@deprecation_utils.deprecated(
    '2021-06-01', 'Use resolve_input_artifacts_v2() instead.')
def resolve_input_artifacts(
    metadata_handler: metadata.Metadata, node_inputs: pipeline_pb2.NodeInputs
) -> Optional[typing_utils.ArtifactMultiMap]:
  """Resolves input artifacts of a pipeline node (legacy, deprecated).

  Deprecated since 2021-06-01; use `resolve_input_artifacts_v2()` instead.
  Unlike v2, this legacy version signals most failure modes by returning
  `None` rather than raising, so callers typically check the result with
  `is None`.

  Args:
    metadata_handler: A metadata handler to access MLMD store.
    node_inputs: A pipeline_pb2.NodeInputs message that instructs artifact
      resolution for a pipeline node.

  Returns:
    If `min_count` for every input is met, returns a Dict[str, List[Artifact]].
    Otherwise, return None.

  Raises:
    ValueError: If resolver steps produce a list of artifact multimaps whose
      length is not exactly 1 (a single resolved input dict is expected).
    TypeError: If the resolver result is neither an artifact multimap nor a
      list of artifact multimaps.
  """
  # Initial channel-query resolution from MLMD, before any resolver steps run.
  initial_dict = _resolve_initial_dict(metadata_handler, node_inputs)
  # Legacy contract: if any input channel has fewer artifacts than its
  # min_count, resolution fails softly with None (v2 raises instead).
  if not _is_sufficient(initial_dict, node_inputs):
    min_counts = {key: input_spec.min_count
                  for key, input_spec in node_inputs.inputs.items()}
    logging.warning('Resolved inputs should have %r artifacts, but got %r.',
                    min_counts, initial_dict)
    return None

  try:
    # Apply the node's configured ResolverSteps (e.g. latest-artifact
    # strategies) to the initially resolved artifact dict.
    result = processor.run_resolver_steps(
        initial_dict,
        resolver_steps=node_inputs.resolver_config.resolver_steps,
        store=metadata_handler.store)
  except exceptions.InputResolutionError:
    # If ResolverStrategy has returned None in the middle, InputResolutionError
    # is raised. Legacy input resolution has returned None in this case.
    return None
  except exceptions.SkipSignal:
    # SkipSignal is not fully representable return value in legacy input
    # resolution, but None is the best effort.
    return None

  # Resolver steps may yield either a single artifact multimap or a list of
  # them; the legacy API only supports returning a single dict.
  if typing_utils.is_list_of_artifact_multimap(result):
    result = cast(Sequence[typing_utils.ArtifactMultiMap], result)
    if len(result) != 1:
      raise ValueError(
          'Invalid number of resolved inputs; expected 1 but got '
          f'{len(result)}: {result}')
    return result[0]
  elif typing_utils.is_artifact_multimap(result):
    return cast(typing_utils.ArtifactMultiMap, result)
  else:
    raise TypeError(f'Invalid input resolution result: {result}. Should be '
                    'Mapping[str, Sequence[Artifact]].')


class Trigger(tuple, Sequence[typing_utils.ArtifactMultiMap]):
"""Input resolution result of list of dict."""

Expand Down
87 changes: 24 additions & 63 deletions tfx/orchestration/portable/inputs_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ def testResolveParametersFail(self):
with self.assertRaisesRegex(RuntimeError, 'Parameter value not ready'):
inputs_utils.resolve_parameters(parameters)

def testResolveInputsArtifacts(self):
def testResolveInputArtifactsV2(self):
pipeline = self.load_pipeline_proto(
'pipeline_for_input_resolver_test.pbtxt')
my_example_gen = pipeline.nodes[0].pipeline_node
Expand Down Expand Up @@ -234,17 +234,18 @@ def testResolveInputsArtifacts(self):

# Gets inputs for transform. Should get back what the first ExampleGen
# published in the `output_examples` channel.
transform_inputs = inputs_utils.resolve_input_artifacts(
m, my_transform.inputs)
transform_inputs = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_transform)[0]
self.assertArtifactMapEqual({'examples_1': [output_example],
'examples_2': [output_example]},
transform_inputs)

# Tries to resolve inputs for trainer. As trainer also requires min_count
# for both input channels (from example_gen and from transform) but we did
# not publish anything from transform, it should return nothing.
self.assertIsNone(
inputs_utils.resolve_input_artifacts(m, my_trainer.inputs))
with self.assertRaises(exceptions.FailedPreconditionError):
inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_trainer)

# Tries to resolve inputs for transform after adding a new context query
# to the input spec that refers to a non-existent context. Inputs cannot
Expand All @@ -253,11 +254,11 @@ def testResolveInputsArtifacts(self):
0].context_queries.add()
context_query.type.name = 'non_existent_context'
context_query.name.field_value.string_value = 'non_existent_context'
transform_inputs = inputs_utils.resolve_input_artifacts(
m, my_transform.inputs)
self.assertIsNone(transform_inputs)
with self.assertRaises(exceptions.FailedPreconditionError):
inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_transform)

def testResolverWithLatestArtifactStrategy(self):
def testResolveInputArtifactsV2_LatestArtifactStrategy(self):
pipeline = self.load_pipeline_proto(
'pipeline_for_input_resolver_test.pbtxt')
my_example_gen = pipeline.nodes[0].pipeline_node
Expand All @@ -284,13 +285,13 @@ def testResolverWithLatestArtifactStrategy(self):

# Gets inputs for transform. Should get back what the first ExampleGen
# published in the `output_examples` channel.
transform_inputs = inputs_utils.resolve_input_artifacts(
m, my_transform.inputs)
transform_inputs = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_transform)[0]
self.assertArtifactMapEqual({'examples_1': [output_example_2],
'examples_2': [output_example_2]},
transform_inputs)

def testResolveInputArtifactsOutputKeyUnset(self):
def testResolveInputArtifactsV2_OutputKeyUnset(self):
pipeline = self.load_pipeline_proto(
'pipeline_for_input_resolver_test_output_key_unset.pbtxt')
my_trainer = pipeline.nodes[0].pipeline_node
Expand All @@ -306,8 +307,8 @@ def testResolveInputArtifactsOutputKeyUnset(self):

# Gets inputs for pusher. Should get back what the first Model
# published in the `output_model` channel.
pusher_inputs = inputs_utils.resolve_input_artifacts(
m, my_pusher.inputs)
pusher_inputs = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_pusher)[0]
self.assertArtifactMapEqual({'model': [output_model]},
pusher_inputs)

Expand All @@ -332,41 +333,6 @@ def _append_resolver_step(self, node_pb, cls, config=None):
step_pb.class_path = name_utils.get_full_name(cls)
step_pb.config_json = json_utils.dumps(config or {})

def testResolveInputArtifacts_SkippingStrategy(self):
self._setup_pipeline_for_input_resolver_test()
self._append_resolver_step(self._my_transform, SkippingStrategy)

result = inputs_utils.resolve_input_artifacts(
self._metadata_handler, self._my_transform.inputs)
self.assertIsNone(result)

def testResolveInputArtifacts_NonDictOutput(self):
self._setup_pipeline_for_input_resolver_test()
self._append_resolver_step(self._my_transform, BadOutputStrategy)

with self.assertRaisesRegex(TypeError, 'Invalid input resolution result'):
inputs_utils.resolve_input_artifacts(
self._metadata_handler, self._my_transform.inputs)

def testResolveInputArtifacts_NonDictArg(self):
self._setup_pipeline_for_input_resolver_test()
self._append_resolver_step(self._my_transform, DuplicateOp)
self._append_resolver_step(self._my_transform, IdentityStrategy)

with self.assertRaisesRegex(TypeError, 'Invalid argument type'):
inputs_utils.resolve_input_artifacts(
self._metadata_handler, self._my_transform.inputs)

def testResolveInputArtifacts_MixedStrategyAndOp(self):
self._setup_pipeline_for_input_resolver_test()
self._append_resolver_step(self._my_transform, IdentityStrategy)
self._append_resolver_step(self._my_transform, IdentityOp)

result = inputs_utils.resolve_input_artifacts(
self._metadata_handler, self._my_transform.inputs)
self.assertArtifactMapEqual({'examples_1': self._examples,
'examples_2': self._examples}, result)

def testResolveInputArtifactsV2_Normal(self):
self._setup_pipeline_for_input_resolver_test()

Expand Down Expand Up @@ -523,9 +489,8 @@ def testLatestUnprocessedArtifacts(self):
output_map={'output_examples': [ex2]})
ex2 = output_artifacts['output_examples'][0]

result = inputs_utils.resolve_input_artifacts(
metadata_handler=m,
node_inputs=my_transform.inputs)
result = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_transform)[0]

self.assertArtifactMapEqual({'examples_1': [ex2],
'examples_2': [ex2]}, result)
Expand Down Expand Up @@ -568,9 +533,8 @@ def testLatestUnprocessedArtifacts_IgnoreAlreadyProcessed(self):
m, my_transform, input_map={'examples_1': [ex2],
'examples_2': [ex2]}, output_map=None)

result = inputs_utils.resolve_input_artifacts(
metadata_handler=m,
node_inputs=my_transform.inputs)
result = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_transform)[0]

self.assertArtifactMapEqual({'examples_1': [ex1],
'examples_2': [ex1]}, result)
Expand Down Expand Up @@ -616,11 +580,9 @@ def testLatestUnprocessedArtifacts_NoneIfEverythingProcessed(self):
input_map={'examples_1': [ex2], 'examples_2': [ex2]},
output_map=None)

result = inputs_utils.resolve_input_artifacts(
metadata_handler=m,
node_inputs=my_transform.inputs)

self.assertIsNone(result)
with self.assertRaises(exceptions.InputResolutionError):
inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_transform)

def testLatestArtifacts_withInputKeys(self):
pipeline = self.load_pipeline_proto(
Expand Down Expand Up @@ -666,9 +628,8 @@ def testLatestArtifacts_withInputKeys(self):
input_map={'examples_1': [ex2], 'examples_2': [ex2]},
output_map={'transform_graph': [tf2]})
tf2 = output_artifacts['transform_graph'][0]
result = inputs_utils.resolve_input_artifacts(
metadata_handler=m,
node_inputs=my_trainer.inputs)
result = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=my_trainer)[0]

# "examples" input channel doesn't go through the resolver and its order is
# undeterministic. Sort artifacts by ID for testing convenience.
Expand Down
13 changes: 7 additions & 6 deletions tfx/orchestration/portable/launcher_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ class _FakeExampleGenLikeDriver(base_driver.BaseDriver):

def __init__(self, mlmd_connection: metadata.Metadata):
super().__init__(mlmd_connection)
self._self_output = text_format.Parse(
node_inputs = text_format.Parse(
"""
inputs {
key: "examples"
Expand Down Expand Up @@ -173,17 +173,18 @@ def __init__(self, mlmd_connection: metadata.Metadata):
}
output_key: "output_examples"
}
min_count: 1
min_count: 0
}
}""", pipeline_pb2.NodeInputs())
self._pipeline_node = pipeline_pb2.PipelineNode(inputs=node_inputs)

def run(self, execution_info) -> driver_output_pb2.DriverOutput:
# Fake a constant span number, which, on prod, is usually calculated based
# on date.
span = 2
with self._mlmd_connection as m:
previous_output = inputs_utils.resolve_input_artifacts(
m, self._self_output)
previous_output = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=self._pipeline_node)[0]

# Version should be the max of existing version + 1 if span exists,
# otherwise 0.
Expand Down Expand Up @@ -710,8 +711,8 @@ def create_test_launcher(executor_operators):
exec_properties = data_types_utils.build_parsed_value_dict(
inputs_utils.resolve_parameters_with_schema(
node_parameters=test_launcher._pipeline_node.parameters))
input_artifacts = inputs_utils.resolve_input_artifacts(
metadata_handler=m, node_inputs=test_launcher._pipeline_node.inputs)
input_artifacts = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=test_launcher._pipeline_node)[0]
first_execution = test_launcher._register_or_reuse_execution(
metadata_handler=m,
contexts=contexts,
Expand Down
4 changes: 2 additions & 2 deletions tfx/orchestration/portable/resolver_node_handler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,8 +142,8 @@ def testRun_ExecutionCompleted(self):
}
upstream_nodes: "my_resolver"
""", pipeline_pb2.PipelineNode())
downstream_input_artifacts = inputs_utils.resolve_input_artifacts(
metadata_handler=m, node_inputs=down_stream_node.inputs)
downstream_input_artifacts = inputs_utils.resolve_input_artifacts_v2(
metadata_handler=m, pipeline_node=down_stream_node)[0]
downstream_input_model = downstream_input_artifacts['input_models']
self.assertLen(downstream_input_model, 1)
self.assertProtoPartiallyEquals(
Expand Down

0 comments on commit 800367e

Please sign in to comment.