Skip to content

Commit

Permalink
PY2 cleanup on tfx/types.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 373506553
  • Loading branch information
chongkong authored and tfx-copybara committed May 13, 2021
1 parent f48399c commit 5194606
Show file tree
Hide file tree
Showing 19 changed files with 144 additions and 240 deletions.
1 change: 0 additions & 1 deletion tfx/types/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down
73 changes: 34 additions & 39 deletions tfx/types/artifact.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,16 +13,12 @@
# limitations under the License.
"""TFX artifact type definition."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import builtins
import copy
import enum
import importlib
import json
from typing import Any, Dict, List, Optional, Text, Type, Union
from typing import Any, Dict, List, Optional, Type, Union

from absl import logging
from tfx.utils import doc_controls
Expand All @@ -34,7 +29,7 @@
from ml_metadata.proto import metadata_store_pb2


class ArtifactState(object):
class ArtifactState:
"""Enumeration of possible Artifact states."""

# Indicates that there is a pending execution producing the artifact.
Expand Down Expand Up @@ -75,7 +70,7 @@ class PropertyType(enum.Enum):
JSON_VALUE = 4


class Property(object):
class Property:
"""Property specified for an Artifact."""
_ALLOWED_MLMD_TYPES = {
PropertyType.INT: metadata_store_pb2.INT,
Expand All @@ -97,7 +92,7 @@ def __repr__(self):
return str(self.type)


JsonValueType = Union[Dict, List, int, float, type(None), Text]
JsonValueType = Union[Dict, List, int, float, type(None), str]
_JSON_SINGLE_VALUE_KEY = '__value__'


Expand Down Expand Up @@ -220,7 +215,7 @@ def __init__(
def _get_artifact_type(cls):
if not getattr(cls, '_MLMD_ARTIFACT_TYPE', None):
type_name = cls.TYPE_NAME
if not (type_name and isinstance(type_name, (str, Text))):
if not (type_name and isinstance(type_name, str)):
raise ValueError(
('The Artifact subclass %s must override the TYPE_NAME attribute '
'with a string type name identifier (got %r instead).') %
Expand All @@ -234,7 +229,7 @@ def _get_artifact_type(cls):
'Artifact subclass %s.PROPERTIES is not a dictionary.' % cls)
for key, value in cls.PROPERTIES.items():
if not (isinstance(key,
(Text, bytes)) and isinstance(value, Property)):
(str, bytes)) and isinstance(value, Property)):
raise ValueError(
('Artifact subclass %s.PROPERTIES dictionary must have keys of '
'type string and values of type artifact.Property.') % cls)
Expand All @@ -245,7 +240,7 @@ def _get_artifact_type(cls):
cls._MLMD_ARTIFACT_TYPE = artifact_type
return copy.deepcopy(cls._MLMD_ARTIFACT_TYPE)

def __getattr__(self, name: Text) -> Any:
def __getattr__(self, name: str) -> Any:
"""Custom __getattr__ to allow access to artifact properties."""
if name == '_artifact_type':
# Prevent infinite recursion when used with copy.deepcopy().
Expand Down Expand Up @@ -285,7 +280,7 @@ def __getattr__(self, name: Text) -> Any:
raise Exception('Unknown MLMD type %r for property %r.' %
(property_mlmd_type, name))

def __setattr__(self, name: Text, value: Any):
def __setattr__(self, name: str, value: Any):
"""Custom __setattr__ to allow access to artifact properties."""
if not self._initialized:
object.__setattr__(self, name, value)
Expand All @@ -304,7 +299,7 @@ def __setattr__(self, name: Text, value: Any):
(name, self))
property_mlmd_type = self._artifact_type.properties[name]
if property_mlmd_type == metadata_store_pb2.STRING:
if not isinstance(value, (Text, bytes)):
if not isinstance(value, (str, bytes)):
raise Exception(
'Expected string value for property %r; got %r instead.' %
(name, value))
Expand Down Expand Up @@ -358,7 +353,7 @@ def __repr__(self):
str(self.mlmd_artifact), str(self._artifact_type))

@doc_controls.do_not_doc_inheritable
def to_json_dict(self) -> Dict[Text, Any]:
def to_json_dict(self) -> Dict[str, Any]:
return {
'artifact':
json.loads(
Expand All @@ -378,7 +373,7 @@ def to_json_dict(self) -> Dict[Text, Any]:

@classmethod
@doc_controls.do_not_doc_inheritable
def from_json_dict(cls, dict_data: Dict[Text, Any]) -> Any:
def from_json_dict(cls, dict_data: Dict[str, Any]) -> Any:
module_name = dict_data['__artifact_class_module__']
class_name = dict_data['__artifact_class_name__']
artifact = metadata_store_pb2.Artifact()
Expand Down Expand Up @@ -451,12 +446,12 @@ def mlmd_artifact(self):
# Settable properties for all artifact types.
@property
@doc_controls.do_not_doc_in_subclasses
def uri(self) -> Text:
def uri(self) -> str:
"""Artifact URI."""
return self._artifact.uri

@uri.setter
def uri(self, uri: Text):
def uri(self, uri: str):
"""Setter for artifact URI."""
self._artifact.uri = uri

Expand Down Expand Up @@ -501,14 +496,14 @@ def type_id(self, type_id: int):
# - producer_component: The name of the component that produces the
# artifact (in a subsequent change, this information will move to the
# associated ML Metadata Event object).
def _get_system_property(self, key: Text) -> Text:
def _get_system_property(self, key: str) -> str:
if (key in self._artifact_type.properties and
key in self._artifact.properties):
# Legacy artifact types which have explicitly defined system properties.
return self._artifact.properties[key].string_value
return self._artifact.custom_properties[key].string_value

def _set_system_property(self, key: Text, value: Text):
def _set_system_property(self, key: str, value: str):
if (key in self._artifact_type.properties and
key in self._artifact.properties):
# Clear non-custom property in legacy artifact types.
Expand All @@ -517,85 +512,85 @@ def _set_system_property(self, key: Text, value: Text):

@property
@doc_controls.do_not_doc_inheritable
def name(self) -> Text:
def name(self) -> str:
"""Name of the underlying mlmd artifact."""
return self._get_system_property('name')

@name.setter
def name(self, name: Text):
def name(self, name: str):
"""Set name of the underlying artifact."""
self._set_system_property('name', name)

@property
@doc_controls.do_not_doc_in_subclasses
def state(self) -> Text:
def state(self) -> str:
"""State of the underlying mlmd artifact."""
return self._get_system_property('state')

@state.setter
def state(self, state: Text):
def state(self, state: str):
"""Set state of the underlying artifact."""
self._set_system_property('state', state)

@property
@doc_controls.do_not_doc_in_subclasses
def pipeline_name(self) -> Text:
def pipeline_name(self) -> str:
"""Name of the pipeline that produce the artifact."""
return self._get_system_property('pipeline_name')

@pipeline_name.setter
def pipeline_name(self, pipeline_name: Text):
def pipeline_name(self, pipeline_name: str):
"""Set name of the pipeline that produce the artifact."""
self._set_system_property('pipeline_name', pipeline_name)

@property
@doc_controls.do_not_doc_inheritable
def producer_component(self) -> Text:
def producer_component(self) -> str:
"""Producer component of the artifact."""
return self._get_system_property('producer_component')

@producer_component.setter
def producer_component(self, producer_component: Text):
def producer_component(self, producer_component: str):
"""Set producer component of the artifact."""
self._set_system_property('producer_component', producer_component)

# Custom property accessors.
@doc_controls.do_not_doc_in_subclasses
def set_string_custom_property(self, key: Text, value: Text):
def set_string_custom_property(self, key: str, value: str):
"""Set a custom property of string type."""
self._artifact.custom_properties[key].string_value = value

@doc_controls.do_not_doc_in_subclasses
def set_int_custom_property(self, key: Text, value: int):
def set_int_custom_property(self, key: str, value: int):
"""Set a custom property of int type."""
self._artifact.custom_properties[key].int_value = builtins.int(value)

@doc_controls.do_not_doc_in_subclasses
def set_float_custom_property(self, key: Text, value: float):
def set_float_custom_property(self, key: str, value: float):
"""Sets a custom property of float type."""
self._artifact.custom_properties[key].double_value = builtins.float(value)

@doc_controls.do_not_doc_inheritable
def set_json_value_custom_property(self, key: Text, value: JsonValueType):
def set_json_value_custom_property(self, key: str, value: JsonValueType):
"""Sets a custom property of float type."""
self._cached_json_value_custom_properties[key] = value

@doc_controls.do_not_doc_in_subclasses
def has_custom_property(self, key: Text) -> bool:
def has_custom_property(self, key: str) -> bool:
return key in self._artifact.custom_properties

@doc_controls.do_not_doc_in_subclasses
def get_string_custom_property(self, key: Text) -> Text:
def get_string_custom_property(self, key: str) -> str:
"""Get a custom property of string type."""
if key not in self._artifact.custom_properties:
return ''
json_value = self.get_json_value_custom_property(key)
if isinstance(json_value, Text):
if isinstance(json_value, str):
return json_value
return self._artifact.custom_properties[key].string_value

@doc_controls.do_not_doc_in_subclasses
def get_int_custom_property(self, key: Text) -> int:
def get_int_custom_property(self, key: str) -> int:
"""Get a custom property of int type."""
if key not in self._artifact.custom_properties:
return 0
Expand All @@ -606,7 +601,7 @@ def get_int_custom_property(self, key: Text) -> int:

# TODO(b/179215351): Standardize type name into one of float and double.
@doc_controls.do_not_doc_in_subclasses
def get_float_custom_property(self, key: Text) -> float:
def get_float_custom_property(self, key: str) -> float:
"""Gets a custom property of float type."""
if key not in self._artifact.custom_properties:
return 0.0
Expand All @@ -616,7 +611,7 @@ def get_float_custom_property(self, key: Text) -> float:
return self._artifact.custom_properties[key].double_value

@doc_controls.do_not_doc_inheritable
def get_json_value_custom_property(self, key: Text) -> JsonValueType:
def get_json_value_custom_property(self, key: str) -> JsonValueType:
"""Get a custom property of int type."""
if key in self._cached_json_value_custom_properties:
return self._cached_json_value_custom_properties[key]
Expand Down Expand Up @@ -649,7 +644,7 @@ def copy_from(self, other: 'Artifact'):

def _ArtifactType( # pylint: disable=invalid-name
name: Optional[str] = None,
properties: Optional[Dict[Text, Property]] = None,
properties: Optional[Dict[str, Property]] = None,
mlmd_artifact_type: Optional[metadata_store_pb2.ArtifactType] = None
) -> Type[Artifact]:
"""Experimental interface: internal use only.
Expand Down
39 changes: 16 additions & 23 deletions tfx/types/artifact_test.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
# Lint as: python2, python3
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand All @@ -14,16 +13,10 @@
# limitations under the License.
"""Tests for tfx.types.artifact."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import json
import textwrap
from typing import Text
from unittest import mock

# Standard Imports

import absl
import tensorflow as tf
Expand Down Expand Up @@ -90,8 +83,8 @@ class _MyArtifact(artifact.Artifact):
class _MyValueArtifact(value_artifact.ValueArtifact):
TYPE_NAME = 'MyValueTypeName'

def encode(self, value: Text):
assert isinstance(value, Text), value
def encode(self, value: str):
assert isinstance(value, str), value
return value.encode('utf-8')

def decode(self, value: bytes):
Expand Down Expand Up @@ -123,10 +116,10 @@ def testArtifact(self):
self.assertEqual('', instance.state)

# Default property does not have span or split_names.
with self.assertRaisesRegexp(AttributeError, "has no property 'span'"):
with self.assertRaisesRegex(AttributeError, "has no property 'span'"):
instance.span # pylint: disable=pointless-statement
with self.assertRaisesRegexp(AttributeError,
"has no property 'split_names'"):
with self.assertRaisesRegex(AttributeError,
"has no property 'split_names'"):
instance.split_names # pylint: disable=pointless-statement

# Test property setters.
Expand All @@ -143,11 +136,11 @@ def testArtifact(self):
self.assertEqual(artifact.ArtifactState.DELETED, instance.state)

# Default artifact does not have span.
with self.assertRaisesRegexp(AttributeError, "unknown property 'span'"):
with self.assertRaisesRegex(AttributeError, "unknown property 'span'"):
instance.span = 20190101
# Default artifact does not have span.
with self.assertRaisesRegexp(AttributeError,
"unknown property 'split_names'"):
with self.assertRaisesRegex(AttributeError,
"unknown property 'split_names'"):
instance.split_names = ''

instance.set_int_custom_property('int_key', 20)
Expand Down Expand Up @@ -764,15 +757,15 @@ def testArtifactJsonValue(self):
)"""), str(copied_artifact))

def testInvalidArtifact(self):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError, 'The "mlmd_artifact_type" argument must be passed'):
artifact.Artifact()

class MyBadArtifact(artifact.Artifact):
# No TYPE_NAME
pass

with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError,
'The Artifact subclass .* must override the TYPE_NAME attribute '):
MyBadArtifact()
Expand All @@ -784,7 +777,7 @@ class MyNewArtifact(artifact.Artifact):
MyNewArtifact()

# Not okay to pass type_name on subclass.
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError,
'The "mlmd_artifact_type" argument must not be passed for Artifact '
'subclass'):
Expand All @@ -808,20 +801,20 @@ def testArtifactProperties(self):
self.assertEqual(my_artifact.get_int_custom_property('invalid'), 0)
self.assertNotIn('invalid', my_artifact._artifact.custom_properties)

with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AttributeError, "Cannot set unknown property 'invalid' on artifact"):
my_artifact.invalid = 1

with self.assertRaisesRegexp(
with self.assertRaisesRegex(
AttributeError, "Cannot set unknown property 'invalid' on artifact"):
my_artifact.invalid = 'x'

with self.assertRaisesRegexp(AttributeError,
"Artifact has no property 'invalid'"):
with self.assertRaisesRegex(AttributeError,
"Artifact has no property 'invalid'"):
my_artifact.invalid # pylint: disable=pointless-statement

def testStringTypeNameNotAllowed(self):
with self.assertRaisesRegexp(
with self.assertRaisesRegex(
ValueError,
'The "mlmd_artifact_type" argument must be an instance of the proto '
'message'):
Expand Down
Loading

0 comments on commit 5194606

Please sign in to comment.