Skip to content

Commit

Permalink
ArtifactPlaceholder defaults to access first artifact if not specified.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 360594885
  • Loading branch information
tfx-copybara committed Mar 3, 2021
1 parent 7a08c23 commit 5d4448c
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 15 deletions.
14 changes: 12 additions & 2 deletions tfx/dsl/placeholder/placeholder.py
Original file line number Diff line number Diff line change
Expand Up @@ -320,22 +320,30 @@ class ArtifactPlaceholder(Placeholder):

@property
def uri(self):
self._try_inject_index_operator()
self._operators.append(_ArtifactUriOperator())
return self

def split_uri(self, split: str):
self._try_inject_index_operator()
self._operators.append(_ArtifactUriOperator(split))
return self

@property
def value(self):
self._try_inject_index_operator()
self._operators.append(_ArtifactValueOperator())
return self

def __getitem__(self, key: int):
self._operators.append(_IndexOperator(key))
return self

def _try_inject_index_operator(self):
if not self._operators or not isinstance(self._operators[-1],
_IndexOperator):
self._operators.append(_IndexOperator(0))


class _ProtoAccessiblePlaceholder(Placeholder, abc.ABC):
"""A base Placeholder for accessing proto fields using Python proto syntax."""
Expand Down Expand Up @@ -420,7 +428,8 @@ def input(key: str) -> ArtifactPlaceholder: # pylint: disable=redefined-builtin
1. Rendering the whole MLMD artifact proto as text_format.
Example: input('model')
2. Accessing a specific index using [index], if multiple artifacts are
associated with the given key.
associated with the given key. If not specified, default to the first
artifact.
Example: input('model')[0]
3. Getting the URI of an artifact through .uri property.
Example: input('model').uri or input('model')[0].uri
Expand Down Expand Up @@ -449,7 +458,8 @@ def output(key: str) -> ArtifactPlaceholder:
1. Rendering the whole artifact as text_format.
Example: output('model')
2. Accessing a specific index using [index], if multiple artifacts are
associated with the given key.
associated with the given key. If not specified, default to the first
artifact.
Example: output('model')[0]
3. Getting the URI of an artifact through .uri property.
Example: output('model').uri or output('model')[0].uri
Expand Down
53 changes: 40 additions & 13 deletions tfx/dsl/placeholder/placeholder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,15 +38,22 @@ def _assert_placeholder_pb_equal_and_deepcopyable(self, placeholder,
del placeholder
self.assertProtoEquals(placeholder_copy.encode(), expected_pb)

def testArtifactUriSimple(self):
def testArtifactUriWithDefault0Index(self):
self._assert_placeholder_pb_equal_and_deepcopyable(
ph.input('model').uri, """
operator {
artifact_uri_op {
expression {
placeholder {
type: INPUT_ARTIFACT
key: "model"
operator {
index_op {
expression {
placeholder {
type: INPUT_ARTIFACT
key: "model"
}
}
index: 0
}
}
}
}
Expand Down Expand Up @@ -104,9 +111,16 @@ def testPrimitiveArtifactValue(self):
operator {
artifact_value_op {
expression {
placeholder {
type: INPUT_ARTIFACT
key: "primitive"
operator {
index_op {
expression {
placeholder {
type: INPUT_ARTIFACT
key: "primitive"
}
}
index: 0
}
}
}
}
Expand All @@ -122,9 +136,15 @@ def testConcatUriWithString(self):
operator {
artifact_uri_op {
expression {
placeholder {
type: OUTPUT_ARTIFACT
key: "model"
operator {
index_op {
expression {
placeholder {
type: OUTPUT_ARTIFACT
key: "model"
}
}
}
}
}
}
Expand Down Expand Up @@ -193,9 +213,16 @@ def testComplicatedConcat(self):
operator {
artifact_uri_op {
expression {
placeholder {
type: OUTPUT_ARTIFACT
key: "model"
operator {
index_op {
expression {
placeholder {
type: OUTPUT_ARTIFACT
key: "model"
}
}
index: 0
}
}
}
}
Expand Down

0 comments on commit 5d4448c

Please sign in to comment.