Skip to content

Commit

Permalink
Implement other primitive typed artifacts: bytes/int/float.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 301904302
  • Loading branch information
tfx-copybara authored and tensorflow-extended-team committed Mar 19, 2020
1 parent 86ee553 commit a64ea68
Show file tree
Hide file tree
Showing 6 changed files with 130 additions and 20 deletions.
8 changes: 4 additions & 4 deletions RELEASE.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Major Features and Improvements
* Add run/pipeline link when creating runs/pipelines on KFP through TFX CLI.
* Added support for ValueArtifact, whose attribute `value` allows users to
access the content of the underlying file directly in the executor. The
Bytes, Integer, String and Float types are supported. Note: interactive
resolution does not support this for now.

## Bug fixes and other changes
* Replaced relative import with absolute import in generated templates.
Expand Down Expand Up @@ -49,10 +53,6 @@
# Version 0.21.1

## Major Features and Improvements
* Added support for ValueArtifact, whose attribute `value` allows users to
access the content of the underlying file directly in the executor. The
first ValueArtifact type implemented was StringType. Note: interactive
resolution does not support this for now.
* Pipelines compiled using KubeflowDagRunner now defaults to using the
gRPC-based MLMD server deployed in Kubeflow Pipelines clusters when
performing operations on pipeline metadata.
Expand Down
5 changes: 2 additions & 3 deletions tfx/components/base/base_driver.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,9 +151,8 @@ def resolve_input_artifacts(
# TODO(ccy): add this code path to interactive resolution.
for artifact in result[name]:
if isinstance(artifact, types.ValueArtifact):
# Resolve the content of file into value field for string-typed
# artifacts.
artifact.read()
# Resolve the content of file into value field for value artifacts.
_ = artifact.read()
return result

def resolve_exec_properties(
Expand Down
8 changes: 6 additions & 2 deletions tfx/types/artifact.py
Original file line number Diff line number Diff line change
Expand Up @@ -494,12 +494,16 @@ def value(self, value):
self._value = value
self.write(value)

# Note: behavior of decode() method should not be changed to provide
# backward/forward compatibility.
@abc.abstractmethod
def decode(self, value) -> bytes:
def decode(self, serialized_value) -> Any:
  """Method decoding the file content. Implemented by subclasses.

  Args:
    serialized_value: the serialized content to be decoded.

  Returns:
    The decoded value. Its type is chosen by the implementing subclass, hence
    the `Any` annotation (decoding yields a value, not serialized bytes).
  """
  pass

# Note: behavior of encode() method should not be changed to provide
# backward/forward compatibility.
@abc.abstractmethod
def encode(self, serialized_value) -> Any:
def encode(self, value) -> bytes:
  """Method encoding the file content. Implemented by subclasses.

  Args:
    value: the value to be serialized. Its accepted type is chosen by the
      implementing subclass.

  Returns:
    The serialized form of `value`, hence the `bytes` annotation.
  """
  pass
8 changes: 4 additions & 4 deletions tfx/types/artifact_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,12 @@ def deserialize_artifact(

# Search the whole Artifact type ontology for a matching class.
def find_subclasses(cls):
all_subclasses = []
result = []
for subclass in cls.__subclasses__():
all_subclasses.append(subclass)
all_subclasses.extend(find_subclasses(subclass))
result.append(subclass)
result.extend(find_subclasses(subclass))

return all_subclasses
return result

for cls in find_subclasses(Artifact):
if cls.TYPE_NAME == artifact_type.name:
Expand Down
57 changes: 50 additions & 7 deletions tfx/types/standard_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
from __future__ import division
from __future__ import print_function

import struct
from typing import Text

from tfx.types.artifact import Artifact
Expand Down Expand Up @@ -92,19 +93,61 @@ class Schema(Artifact):
TYPE_NAME = 'Schema'


class BytesType(ValueArtifact):
  """Value artifact whose content is handled as raw, uninterpreted bytes."""
  TYPE_NAME = 'BytesType'

  def encode(self, value: bytes):
    """Returns `value` unchanged after checking that it is a `bytes` object.

    Raises:
      TypeError: if `value` is not an instance of `bytes`.
    """
    if isinstance(value, bytes):
      return value
    raise TypeError('Expecting bytes but got value %s of type %s' %
                    (str(value), type(value)))

  def decode(self, serialized_value: bytes):
    """Returns the serialized content as-is; no transformation is applied."""
    return serialized_value


class StringType(ValueArtifact):
"""String-typed Artifact."""
"""String-typed artifact."""
TYPE_NAME = 'StringType'

# Note, currently we enforce unicode-encoded string.
def encode(self, value: Text):
# TODO(jxzheng): Enforce value to be Text after dropping support of Py2.
# likely after 0.21.2
# assert isinstance(value, Text), value
def encode(self, value: Text) -> bytes:
if not isinstance(value, Text):
raise TypeError('Expecting Text but got value %s of type %s' %
(str(value), type(value)))
return value.encode('utf-8')

def decode(self, value: bytes):
return value.decode('utf-8')
def decode(self, serialized_value: bytes) -> Text:
return serialized_value.decode('utf-8')


class IntegerType(ValueArtifact):
  """Value artifact holding a single integer."""
  TYPE_NAME = 'IntegerType'

  def encode(self, value: int) -> bytes:
    """Packs `value` with the struct format '>i' (big-endian, 4-byte signed).

    Raises:
      TypeError: if `value` is not an instance of `int`.
    """
    if isinstance(value, int):
      return struct.pack('>i', value)
    raise TypeError('Expecting int but got value %s of type %s' %
                    (str(value), type(value)))

  def decode(self, serialized_value: bytes) -> int:
    """Unpacks content written with the struct format '>i' into an int."""
    # '>i' describes exactly one field, so the unpacked tuple has one element.
    (number,) = struct.unpack('>i', serialized_value)
    return number


class FloatType(ValueArtifact):
  """Value artifact holding a single floating point number."""
  TYPE_NAME = 'FloatType'

  def encode(self, value: float) -> bytes:
    """Packs `value` with the struct format '>d' (big-endian, 8-byte double).

    Raises:
      TypeError: if `value` is not an instance of `float`.
    """
    if isinstance(value, float):
      return struct.pack('>d', value)
    raise TypeError('Expecting float but got value %s of type %s' %
                    (str(value), type(value)))

  def decode(self, serialized_value: bytes) -> float:
    """Unpacks content written with the struct format '>d' into a float."""
    # '>d' describes exactly one field, so the unpacked tuple has one element.
    (number,) = struct.unpack('>d', serialized_value)
    return number


class TransformGraph(Artifact):
Expand Down
64 changes: 64 additions & 0 deletions tfx/types/standard_artifacts_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,64 @@
# Lint as: python2, python3
# Copyright 2020 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for standard TFX Artifact types."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import tensorflow as tf

from tfx.types import standard_artifacts

# Constant values for tests. Each *_RAW constant is the serialized form that
# encode() is expected to produce for the matching *_DECODED value, and that
# decode() is expected to turn back into it.

# BytesType: raw bytes pass through unchanged in both directions.
_TEST_BYTE_RAW = b'hello world'
_TEST_BYTE_DECODED = b'hello world'

# StringType: text is serialized as UTF-8.
_TEST_STRING_RAW = b'hello world'
_TEST_STRING_DECODED = u'hello world'

# IntegerType: 19260817 == 0x0125E591, packed with struct format '>i'
# (the '%' in the literal is byte 0x25).
_TEST_INT_RAW = b'\x01%\xe5\x91'
_TEST_INT_DECODED = 19260817

# FloatType: expected struct format '>d' packing of 3.1415926535.
_TEST_FLOAT_RAW = b'@\t!\xfbTA\x17D'
_TEST_FLOAT_DECODED = 3.1415926535


class StandardArtifactsTest(tf.test.TestCase):
  """Checks encode()/decode() of the primitive value artifacts on fixed data."""

  def testBytesType(self):
    """BytesType encodes and decodes bytes without modification."""
    artifact = standard_artifacts.BytesType()
    encoded = artifact.encode(_TEST_BYTE_DECODED)
    self.assertEqual(_TEST_BYTE_RAW, encoded)
    decoded = artifact.decode(_TEST_BYTE_RAW)
    self.assertEqual(_TEST_BYTE_DECODED, decoded)

  def testStringType(self):
    """StringType converts between unicode text and its serialized bytes."""
    artifact = standard_artifacts.StringType()
    encoded = artifact.encode(_TEST_STRING_DECODED)
    self.assertEqual(_TEST_STRING_RAW, encoded)
    decoded = artifact.decode(_TEST_STRING_RAW)
    self.assertEqual(_TEST_STRING_DECODED, decoded)

  def testIntegerType(self):
    """IntegerType converts between an int and its packed bytes."""
    artifact = standard_artifacts.IntegerType()
    encoded = artifact.encode(_TEST_INT_DECODED)
    self.assertEqual(_TEST_INT_RAW, encoded)
    decoded = artifact.decode(_TEST_INT_RAW)
    self.assertEqual(_TEST_INT_DECODED, decoded)

  def testFloatType(self):
    """FloatType converts between a float and its packed bytes."""
    artifact = standard_artifacts.FloatType()
    encoded = artifact.encode(_TEST_FLOAT_DECODED)
    self.assertEqual(_TEST_FLOAT_RAW, encoded)
    decoded = artifact.decode(_TEST_FLOAT_RAW)
    # The decoded float is compared approximately, as in the exact-bytes
    # check above only the encoding direction is pinned.
    self.assertAlmostEqual(_TEST_FLOAT_DECODED, decoded)


if __name__ == '__main__':
tf.test.main()

0 comments on commit a64ea68

Please sign in to comment.