Skip to content

Commit

Permalink
Introduce channel.union and UnionChannel class to support multiple Ch…
Browse files Browse the repository at this point in the history
…annels

PiperOrigin-RevId: 400776676
  • Loading branch information
tfx-copybara committed Oct 4, 2021
1 parent f87be52 commit e032862
Show file tree
Hide file tree
Showing 7 changed files with 204 additions and 86 deletions.
1 change: 1 addition & 0 deletions tfx/types/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
"""Subpackage for TFX pipeline types."""

from tfx.types.artifact import Artifact
from tfx.types.channel import BaseChannel
from tfx.types.channel import Channel
from tfx.types.channel import ExecPropertyTypes
from tfx.types.channel import Property # Type alias.
Expand Down
90 changes: 69 additions & 21 deletions tfx/types/channel.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
import inspect
import json
import textwrap
from typing import Any, Dict, Iterable, List, Optional, Type, Union
from typing import Any, cast, Dict, Iterable, List, Optional, Type, Union
from absl import logging

from tfx.dsl.placeholder import placeholder
Expand All @@ -34,15 +34,32 @@
ExecPropertyTypes = Union[int, float, str, bool, message.Message, List[Any]]


class Channel(json_utils.Jsonable):
class BaseChannel:
  """Base type for the Channels that connect pipeline nodes.

  Attributes:
    type: The artifact type class that the Channel takes.
  """

  def __init__(self, type: Type[Artifact]):  # pylint: disable=redefined-builtin
    # Only Artifact subclasses (classes, not instances or names) are accepted.
    is_artifact_class = inspect.isclass(type) and issubclass(type, Artifact)  # pytype: disable=wrong-arg-types
    if is_artifact_class:
      self.type = type
      return
    raise ValueError(
        'Argument "type" of BaseChannel constructor must be a subclass of '
        'tfx.Artifact (got %r).' % (type,))

  @property
  def type_name(self):
    """Returns the TYPE_NAME of the artifact type class of this Channel."""
    artifact_class = self.type
    return artifact_class.TYPE_NAME


class Channel(json_utils.Jsonable, BaseChannel):
"""Tfx Channel.
TFX Channel is an abstract concept that connects data producers and data
consumers. It contains restriction of the artifact type that should be fed
into or read from it.
Attributes:
type: The artifact type class that the Channel takes.
"""

# TODO(b/125348988): Add support for real Channel in addition to static ones.
Expand Down Expand Up @@ -70,16 +87,11 @@ def __init__(
producer_component_id: (Optional) Producer component id of the Channel.
This argument is internal/experimental and is subject to change in the
future.
output_key: (Optional) The output key when producer component produces
the artifacts in this Channel. This argument is internal/experimental
and is subject to change in the future.
output_key: (Optional) The output key when producer component produces the
artifacts in this Channel. This argument is internal/experimental and is
subject to change in the future.
"""
if not (inspect.isclass(type) and issubclass(type, Artifact)): # pytype: disable=wrong-arg-types
raise ValueError(
'Argument "type" of Channel constructor must be a subclass of '
'tfx.Artifact (got %r).' % (type,))

self.type = type
super().__init__(type=type)

self.additional_properties = additional_properties or {}
self.additional_custom_properties = additional_custom_properties or {}
Expand All @@ -94,11 +106,6 @@ def __init__(
self._artifacts = []
self._matching_channel_name = None

@property
def type_name(self):
"""Name of the artifact type class that Channel takes."""
return self.type.TYPE_NAME

def __repr__(self):
artifacts_str = '\n '.join(repr(a) for a in self._artifacts)
return textwrap.dedent("""\
Expand Down Expand Up @@ -168,8 +175,10 @@ def to_json_dict(self) -> Dict[str, Any]:
preserving_proto_field_name=True)),
'artifacts':
list(a.to_json_dict() for a in self._artifacts),
'additional_properties': self.additional_properties,
'additional_custom_properties': self.additional_custom_properties,
'additional_properties':
self.additional_properties,
'additional_custom_properties':
self.additional_custom_properties,
'producer_component_id':
(self.producer_component_id if self.producer_component_id else None
),
Expand All @@ -196,3 +205,42 @@ def from_json_dict(cls, dict_data: Dict[str, Any]) -> Any:

def future(self) -> placeholder.ChannelWrappedPlaceholder:
return placeholder.ChannelWrappedPlaceholder(self)


@doc_controls.do_not_generate_docs
class UnionChannel(BaseChannel):
  """Union of multiple Channels with the same type.

  Prefer to use union() to create UnionChannel.

  Attributes:
    channels: Flat list of the Channel objects being unioned. UnionChannels
      passed to the constructor are expanded into their member channels.
  """

  def __init__(self, type: Type[Artifact], input_channels: List[BaseChannel]):  # pylint: disable=redefined-builtin
    """Initializes the UnionChannel.

    Args:
      type: The artifact type class shared by every unioned channel.
      input_channels: Non-empty list of Channel and/or UnionChannel objects.

    Raises:
      ValueError: If input_channels is empty, or contains an object that is
        neither a Channel nor a UnionChannel.
      TypeError: If any unioned channel has an artifact type other than
        `type`.
    """
    super().__init__(type=type)

    if not input_channels:
      raise ValueError('At least one input channel expected.')

    self.channels = []
    for c in input_channels:
      if isinstance(c, UnionChannel):
        # Flatten nested unions so that `channels` only holds Channels.
        self.channels.extend(cast(UnionChannel, c).channels)
      elif isinstance(c, Channel):
        self.channels.append(c)
      else:
        # Report the class of the rejected object instead of c.type_name: the
        # object may not be a BaseChannel at all (and so may have no
        # type_name), and the artifact type name does not explain the
        # rejection. The builtin type() is shadowed by the `type` argument
        # here, hence __class__.
        raise ValueError(
            'Unexpected channel type: %s.' % c.__class__.__name__)

    self._validate_type()

  def _validate_type(self):
    """Checks that every unioned channel is a Channel of type `self.type`.

    Raises:
      TypeError: If a member is not a Channel or its type differs from
        `self.type`.
    """
    for channel in self.channels:
      if not isinstance(channel, Channel) or channel.type != self.type:
        raise TypeError(
            'Unioned channels must have the same type. Expected %s (got %s).' %
            (self.type, channel.type))


def union(input_channels: Iterable[BaseChannel]) -> UnionChannel:
  """Combines multiple input channels into a single UnionChannel.

  Args:
    input_channels: Non-empty iterable of channels to union. The artifact type
      of the first channel is used as the type of the resulting UnionChannel.

  Returns:
    A UnionChannel built from all of the given channels.
  """
  channels = list(input_channels)
  assert channels, 'Not expecting empty input channels list.'
  first_channel = channels[0]
  return UnionChannel(first_channel.type, channels)
67 changes: 43 additions & 24 deletions tfx/types/channel_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,22 +15,20 @@

import tensorflow as tf
from tfx.dsl.placeholder import placeholder
from tfx.types.artifact import Artifact
from tfx.types.artifact import Property
from tfx.types.artifact import PropertyType
from tfx.types.channel import Channel
from tfx.types import artifact
from tfx.types import channel

from ml_metadata.proto import metadata_store_pb2


class _MyType(Artifact):
class _MyType(artifact.Artifact):
TYPE_NAME = 'MyTypeName'
PROPERTIES = {
'string_value': Property(PropertyType.STRING),
'string_value': artifact.Property(artifact.PropertyType.STRING),
}


class _AnotherType(Artifact):
class _AnotherType(artifact.Artifact):
TYPE_NAME = 'AnotherTypeName'


Expand All @@ -39,58 +37,79 @@ class ChannelTest(tf.test.TestCase):
def testValidChannel(self):
instance_a = _MyType()
instance_b = _MyType()
chnl = Channel(_MyType).set_artifacts([instance_a, instance_b])
chnl = channel.Channel(_MyType).set_artifacts([instance_a, instance_b])
self.assertEqual(chnl.type_name, 'MyTypeName')
self.assertCountEqual(chnl.get(), [instance_a, instance_b])

def testInvalidChannelType(self):
instance_a = _MyType()
instance_b = _MyType()
with self.assertRaises(ValueError):
Channel(_AnotherType).set_artifacts([instance_a, instance_b])
channel.Channel(_AnotherType).set_artifacts([instance_a, instance_b])

def testStringTypeNameNotAllowed(self):
with self.assertRaises(ValueError):
Channel('StringTypeName')
channel.Channel('StringTypeName')

def testJsonRoundTrip(self):
channel = Channel(
chnl = channel.Channel(
type=_MyType,
additional_properties={
'string_value': metadata_store_pb2.Value(string_value='forty-two')
},
additional_custom_properties={
'int_value': metadata_store_pb2.Value(int_value=42)
})
serialized = channel.to_json_dict()
rehydrated = Channel.from_json_dict(serialized)
self.assertIs(channel.type, rehydrated.type)
self.assertEqual(channel.type_name, rehydrated.type_name)
self.assertEqual(channel.additional_properties,
serialized = chnl.to_json_dict()
rehydrated = channel.Channel.from_json_dict(serialized)
self.assertIs(chnl.type, rehydrated.type)
self.assertEqual(chnl.type_name, rehydrated.type_name)
self.assertEqual(chnl.additional_properties,
rehydrated.additional_properties)
self.assertEqual(channel.additional_custom_properties,
self.assertEqual(chnl.additional_custom_properties,
rehydrated.additional_custom_properties)

def testJsonRoundTripUnknownArtifactClass(self):
channel = Channel(type=_MyType)
chnl = channel.Channel(type=_MyType)

serialized = channel.to_json_dict()
serialized = chnl.to_json_dict()
serialized['type']['name'] = 'UnknownTypeName'

rehydrated = Channel.from_json_dict(serialized)
rehydrated = channel.Channel.from_json_dict(serialized)
self.assertEqual('UnknownTypeName', rehydrated.type_name)
self.assertEqual(channel.type._get_artifact_type().properties,
self.assertEqual(chnl.type._get_artifact_type().properties,
rehydrated.type._get_artifact_type().properties)
self.assertTrue(rehydrated.type._AUTOGENERATED)

def testFutureProducesPlaceholder(self):
channel = Channel(type=_MyType)
future = channel.future()
chnl = channel.Channel(type=_MyType)
future = chnl.future()
self.assertIsInstance(future, placeholder.ChannelWrappedPlaceholder)
self.assertIs(future.channel, channel)
self.assertIs(future.channel, chnl)
self.assertIsInstance(future[0], placeholder.ChannelWrappedPlaceholder)
self.assertIsInstance(future.value, placeholder.ChannelWrappedPlaceholder)

def testValidUnionChannel(self):
  """Unioning same-typed channels keeps the type and flattens nesting."""
  channel1 = channel.Channel(type=_MyType)
  channel2 = channel.Channel(type=_MyType)
  union_channel = channel.union([channel1, channel2])
  # Compare strings by value: assertIs checks object identity, which only
  # holds for equal strings when the interpreter happens to intern them.
  self.assertEqual(union_channel.type_name, 'MyTypeName')
  self.assertEqual(union_channel.channels, [channel1, channel2])

  # A nested union is expanded into its member channels.
  union_channel = channel.union([channel1, channel.union([channel2])])
  self.assertEqual(union_channel.type_name, 'MyTypeName')
  self.assertEqual(union_channel.channels, [channel1, channel2])

def testMismatchedUnionChannelType(self):
  """Unioning channels of different artifact types raises TypeError."""
  my_type_channel = channel.Channel(type=_MyType)
  another_type_channel = channel.Channel(type=_AnotherType)
  mixed_channels = [my_type_channel, another_type_channel]
  with self.assertRaises(TypeError):
    channel.union(mixed_channels)

def testEmptyUnionChannel(self):
  """union() rejects an empty list of channels with an AssertionError."""
  self.assertRaises(AssertionError, channel.union, [])


if __name__ == '__main__':
tf.test.main()
34 changes: 27 additions & 7 deletions tfx/types/channel_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,13 @@
# limitations under the License.
"""TFX Channel utilities."""

from typing import Dict, Iterable, List
from typing import cast, Dict, Iterable, List

from tfx.types.artifact import Artifact
from tfx.types.channel import Channel
from tfx.types import artifact
from tfx.types import channel


def as_channel(artifacts: Iterable[Artifact]) -> Channel:
def as_channel(artifacts: Iterable[artifact.Artifact]) -> channel.Channel:
"""Converts artifact collection of the same artifact type into a Channel.
Args:
Expand All @@ -33,16 +33,17 @@ def as_channel(artifacts: Iterable[Artifact]) -> Channel:
"""
try:
first_element = next(iter(artifacts))
if isinstance(first_element, Artifact):
return Channel(type=first_element.type).set_artifacts(artifacts)
if isinstance(first_element, artifact.Artifact):
return channel.Channel(type=first_element.type).set_artifacts(artifacts)
else:
raise ValueError('Invalid artifact iterable: {}'.format(artifacts))
except StopIteration:
raise ValueError('Cannot convert empty artifact iterable into Channel')


def unwrap_channel_dict(
channel_dict: Dict[str, Channel]) -> Dict[str, List[Artifact]]:
channel_dict: Dict[str,
channel.Channel]) -> Dict[str, List[artifact.Artifact]]:
"""Unwrap dict of channels to dict of lists of Artifact.
Args:
Expand All @@ -52,3 +53,22 @@ def unwrap_channel_dict(
a dict of Text -> List[Artifact]
"""
return dict((k, list(v.get())) for k, v in channel_dict.items())


def get_individual_channels(
    input_channel: channel.BaseChannel) -> List[channel.Channel]:
  """Expands a BaseChannel into the list of Channels it stands for.

  Args:
    input_channel: Either a Channel or a UnionChannel.

  Returns:
    A one-element list holding the Channel itself, or a new list with the
    member channels of the UnionChannel.

  Raises:
    RuntimeError: If input_channel is neither a Channel nor a UnionChannel.
  """
  if isinstance(input_channel, channel.Channel):
    return [input_channel]
  if not isinstance(input_channel, channel.UnionChannel):
    raise RuntimeError(f'Unexpected Channel type: {type(input_channel)}')
  union_channel = cast(channel.UnionChannel, input_channel)
  return list(union_channel.channels)


def get_channel_producer_component_ids(
    input_channel: channel.BaseChannel) -> List[str]:
  """Collects the producer_component_id of every individual channel.

  Args:
    input_channel: The channel to inspect; it is expanded with
      get_individual_channels() first.

  Returns:
    The producer_component_id values, in the order of the expanded channels.
  """
  individual_channels = get_individual_channels(input_channel)
  return [each.producer_component_id for each in individual_channels]
22 changes: 18 additions & 4 deletions tfx/types/channel_utils_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@
"""Tests for tfx.types.channel_utils."""

import tensorflow as tf
from tfx.types import artifact
from tfx.types import channel
from tfx.types import channel_utils
from tfx.types.artifact import Artifact
from tfx.types.channel import Channel


class _MyArtifact(Artifact):
class _MyArtifact(artifact.Artifact):
TYPE_NAME = 'MyTypeName'


Expand All @@ -45,11 +45,25 @@ def testUnwrapChannelDict(self):
instance_a = _MyArtifact()
instance_b = _MyArtifact()
channel_dict = {
'id': Channel(_MyArtifact).set_artifacts([instance_a, instance_b])
'id':
channel.Channel(_MyArtifact).set_artifacts([instance_a, instance_b])
}
result = channel_utils.unwrap_channel_dict(channel_dict)
self.assertDictEqual(result, {'id': [instance_a, instance_b]})

def testGetInidividualChannels(self):
  """A Channel maps to itself; a union maps to its members, in order."""
  artifact_a = _MyArtifact()
  artifact_b = _MyArtifact()
  first_channel = channel.Channel(_MyArtifact).set_artifacts([artifact_a])
  second_channel = channel.Channel(_MyArtifact).set_artifacts([artifact_b])

  # A plain Channel is returned as the only element.
  self.assertEqual(
      channel_utils.get_individual_channels(first_channel), [first_channel])

  # A union is expanded into the channels it was built from.
  unioned = channel.union([first_channel, second_channel])
  self.assertEqual(
      channel_utils.get_individual_channels(unioned),
      [first_channel, second_channel])


if __name__ == '__main__':
tf.test.main()
Loading

0 comments on commit e032862

Please sign in to comment.