Skip to content

Commit

Permalink
Add non-ArtifactMultiMap data types for ResolverOps.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 394635470
  • Loading branch information
chongkong authored and tfx-copybara committed Sep 3, 2021
1 parent 7c2e78f commit 7497850
Show file tree
Hide file tree
Showing 4 changed files with 90 additions and 6 deletions.
4 changes: 4 additions & 0 deletions tfx/dsl/components/common/resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,9 +67,13 @@ class ResolverStrategy(abc.ABC):
@classmethod
def as_resolver_op(cls, input_node: resolver_op.OpNode, **kwargs):
  """Wraps this class as an OpNode so it can be used in a resolver_function.

  Args:
    input_node: Upstream OpNode passed as the single argument. Its
      `output_data_type` must equal `DataTypes.ARTIFACT_MULTIMAP`.
    **kwargs: Stored unchanged on the returned OpNode as its `kwargs`.

  Returns:
    An OpNode whose `op_type` is `cls`, whose `arg` is `input_node`, and whose
    `output_data_type` is `DataTypes.ARTIFACT_MULTIMAP`.

  Raises:
    TypeError: If `input_node` produces any other data type.
  """
  upstream_type = input_node.output_data_type
  multimap_type = resolver_op.DataTypes.ARTIFACT_MULTIMAP
  if upstream_type != multimap_type:
    message = (f'{cls.__name__} takes ARTIFACT_MULTIMAP but got '
               f'{upstream_type.name} instead.')
    raise TypeError(message)

  return resolver_op.OpNode(
      op_type=cls,
      arg=input_node,
      output_data_type=multimap_type,
      kwargs=kwargs)

@deprecation_utils.deprecated(
Expand Down
53 changes: 47 additions & 6 deletions tfx/dsl/input_resolution/resolver_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,11 @@
# limitations under the License.
"""Module for ResolverOp and its related definitions."""
import abc
from typing import Any, ClassVar, Generic, Mapping, Type, TypeVar, Union
import enum
from typing import Any, ClassVar, Generic, Mapping, Type, TypeVar, Union, Sequence

import attr
import tfx.types
from tfx.utils import json_utils
from tfx.utils import typing_utils

Expand All @@ -33,6 +35,23 @@ class Context:
# run, and current running node information.


class DataTypes(enum.Enum):
  """Data types that a ResolverOp may take as input or produce as output."""
  ARTIFACT_LIST = Sequence[tfx.types.Artifact]
  ARTIFACT_MULTIMAP = typing_utils.ArtifactMultiMap
  ARTIFACT_MULTIMAP_LIST = Sequence[typing_utils.ArtifactMultiMap]

  def is_acceptable(self, value: Any) -> bool:
    """Reports whether `value` conforms to this data type.

    Each member delegates to the matching `typing_utils` predicate and returns
    its result.

    Args:
      value: The object to check.

    Returns:
      The result of the predicate for this member.

    Raises:
      NotImplementedError: If this member has no predicate wired up here.
    """
    if self is DataTypes.ARTIFACT_LIST:
      return typing_utils.is_homogeneous_artifact_list(value)
    if self is DataTypes.ARTIFACT_MULTIMAP:
      return typing_utils.is_artifact_multimap(value)
    if self is DataTypes.ARTIFACT_MULTIMAP_LIST:
      return typing_utils.is_list_of_artifact_multimap(value)
    raise NotImplementedError(f'Cannot check type for {self}.')


class _ResolverOpMeta(abc.ABCMeta):
"""Metaclass for ResolverOp.
Expand All @@ -49,12 +68,23 @@ class _ResolverOpMeta(abc.ABCMeta):
# method (cls).
# pylint: disable=no-value-for-parameter

def __init__(cls, name, bases, attrs):
def __new__(cls, name, bases, attrs, **kwargs):
  """Creates the class object without forwarding class keyword arguments.

  Keyword arguments written in a class statement are accepted here so that
  class creation does not fail, but only `name`, `bases` and `attrs` are
  passed on to the parent metaclass.
  """
  # pylint: disable=too-many-function-args
  del kwargs  # Intentionally not passed to the parent metaclass.
  new_class = super().__new__(cls, name, bases, attrs)
  return new_class

def __init__(
    cls, name, bases, attrs,
    arg_data_types: Sequence[DataTypes] = (DataTypes.ARTIFACT_MULTIMAP,),
    return_data_type: DataTypes = DataTypes.ARTIFACT_MULTIMAP):
  """Records the declared properties and data types on the new class.

  Args:
    name: Name of the class being created.
    bases: Base classes of the class being created.
    attrs: Class namespace; every ResolverOpProperty value found in it is
      indexed by its `name`.
    arg_data_types: Data types of the op's arguments. Exactly one entry is
      supported.
    return_data_type: Data type of the op's output.

  Raises:
    NotImplementedError: If `arg_data_types` does not have exactly one entry.
  """
  declared_props = {}
  for candidate in attrs.values():
    if isinstance(candidate, ResolverOpProperty):
      declared_props[candidate.name] = candidate
  cls._props_by_name = declared_props

  # Only single-argument ops are supported for now.
  if len(arg_data_types) != 1:
    raise NotImplementedError('len(arg_data_types) should be 1.')
  cls._arg_data_type = arg_data_types[0]
  cls._return_data_type = return_data_type

  super().__init__(name, bases, attrs)

def __call__(cls, arg: 'OpNode', **kwargs: Any):
Expand All @@ -77,14 +107,19 @@ def __call__(cls, arg: 'OpNode', **kwargs: Any):
"""
cls._check_arg(arg)
cls._check_kwargs(kwargs)
return OpNode(op_type=cls, arg=arg, kwargs=kwargs)
return OpNode(
op_type=cls,
arg=arg,
output_data_type=cls._return_data_type,
kwargs=kwargs)

def _check_arg(cls, arg: 'OpNode'):
  """Validates the single positional argument given to the op.

  Args:
    arg: The value passed as the op's argument.

  Raises:
    ValueError: If `arg` is not an OpNode.
    TypeError: If the output data type of `arg` differs from the argument
      data type declared by this op.
  """
  if not isinstance(arg, OpNode):
    raise ValueError('Cannot directly call ResolverOp with real values. Use '
                     'output of another operator as an argument.')

  # The upstream node's output type must match this op's declared input type.
  actual_type = arg.output_data_type
  expected_type = cls._arg_data_type
  if actual_type != expected_type:
    raise TypeError(f'{cls.__name__} takes {expected_type.name} type '
                    f'but got {actual_type.name} instead.')

def _check_kwargs(cls, kwargs: Mapping[str, Any]):
for name, prop in cls._props_by_name.items():
Expand Down Expand Up @@ -242,6 +277,9 @@ class OpNode(Generic[_TOut]):

# ResolverOp class that is used for the Node.
op_type = attr.ib()
# Output data type of ResolverOp.
output_data_type = attr.ib(default=DataTypes.ARTIFACT_MULTIMAP,
validator=attr.validators.instance_of(DataTypes))
# A single argument to the ResolverOp.
arg = attr.ib()
# ResolverOpProperty for the ResolverOp, given as keyword arguments.
Expand Down Expand Up @@ -276,5 +314,8 @@ def is_input_node(self):
return self is OpNode.INPUT_NODE

# Build the INPUT_NODE sentinel with attrs validators switched off; it is
# constructed with op_type=None and arg=None (presumably values the field
# validators would reject -- confirm against the OpNode field definitions).
attr.set_run_validators(False)
try:
  OpNode.INPUT_NODE = OpNode(
      op_type=None,
      output_data_type=DataTypes.ARTIFACT_MULTIMAP,
      arg=None)
finally:
  # set_run_validators is a process-wide switch, so turn it back on even if
  # construction raises rather than leaving validation disabled everywhere.
  attr.set_run_validators(True)
30 changes: 30 additions & 0 deletions tfx/dsl/input_resolution/resolver_op_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for tfx.dsl.input_resolution.resolver_op."""
import copy
from typing import Optional, Mapping

import tensorflow as tf
Expand All @@ -33,6 +34,23 @@ def apply(self, input_dict):
return input_dict


class Repeat(
    resolver_op.ResolverOp,
    return_data_type=resolver_op.DataTypes.ARTIFACT_MULTIMAP_LIST):
  """Test op that returns a list of `n` deep copies of its input."""
  n = resolver_op.ResolverOpProperty(type=int)

  def apply(self, input_dict):
    """Returns `self.n` independent deep copies of `input_dict` as a list."""
    copies = []
    for _ in range(self.n):
      copies.append(copy.deepcopy(input_dict))
    return copies


class TakeLast(
    resolver_op.ResolverOp,
    arg_data_types=(resolver_op.DataTypes.ARTIFACT_MULTIMAP_LIST,)):
  """Test op that takes a list input and returns its final element."""

  def apply(self, input_dicts):
    """Returns the last item of `input_dicts`."""
    last_item = input_dicts[-1]
    return last_item


class ResolverOpTest(tf.test.TestCase):

def testDefineOp_PropertyDefaultViolatesType(self):
Expand Down Expand Up @@ -91,6 +109,18 @@ def testOpCreate_PropertyTypeCheck(self):
TypeError, "foo should be <class 'int'> but got '42'."):
Foo.create(foo='42')

def testOpCreate_ArgumentTypeCheck(self):
  """Checks that calling an op validates the upstream node's data type."""
  source_node = resolver_op.OpNode.INPUT_NODE
  expected_error = ('TakeLast takes ARTIFACT_MULTIMAP_LIST type but got '
                    'ARTIFACT_MULTIMAP instead.')

  # INPUT_NODE yields ARTIFACT_MULTIMAP, which TakeLast does not accept.
  with self.subTest('Need List[Dict] but got Dict.'):
    with self.assertRaisesRegex(TypeError, expected_error):
      TakeLast(source_node)

  # Repeat yields ARTIFACT_MULTIMAP_LIST, so chaining it into TakeLast is fine.
  with self.subTest('No Error'):
    repeated = Repeat(source_node, n=2)
    TakeLast(repeated)

def testOpProperty_DefaultValue(self):
result = Bar.create()
self.assertEqual(result.bar, 'bar')
Expand Down
9 changes: 9 additions & 0 deletions tfx/utils/typing_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,15 @@
ArtifactMultiDict = Dict[str, List[tfx.types.Artifact]]


def is_homogeneous_artifact_list(value: Any) -> bool:
  """Checks value is Sequence[T] where T is subclass of Artifact.

  An empty sequence is accepted. For a non-empty sequence, the first element's
  type must be a subclass of `tfx.types.Artifact` and every remaining element
  must be an instance of that same type.

  Args:
    value: The object to check.

  Returns:
    True if `value` is an acceptable artifact list, False otherwise.
  """
  # str, bytes and bytearray are Sequences too; without this guard an empty
  # string would be vacuously accepted as an artifact list.
  if isinstance(value, (str, bytes, bytearray)):
    return False
  if not isinstance(value, collections.abc.Sequence):
    return False
  if not value:
    return True

  first_type = type(value[0])
  if not issubclass(first_type, tfx.types.Artifact):
    return False
  return all(isinstance(item, first_type) for item in value[1:])


def is_artifact_multimap(value: Any) -> bool:
"""Checks value is Mapping[str, Sequence[Artifact]] type."""
if not isinstance(value, collections.abc.Mapping):
Expand Down

0 comments on commit 7497850

Please sign in to comment.