Support partial run on subpipelines
PiperOrigin-RevId: 591970925
kmonte authored and tfx-copybara committed Dec 18, 2023
1 parent a6345d4 commit 1bf082d
Showing 3 changed files with 194 additions and 168 deletions.
6 changes: 6 additions & 0 deletions tfx/dsl/compiler/compiler_utils.py
@@ -191,6 +191,12 @@ def pipeline_end_node_id_from_pipeline_id(pipeline_id: str) -> str:
return f"{pipeline_id}{constants.PIPELINE_END_NODE_SUFFIX}"


def end_node_context_name_from_subpipeline_id(subpipeline_id: str) -> str:
"""Builds the end_node context name of a composable pipeline."""
end_node_id = pipeline_end_node_id_from_pipeline_id(subpipeline_id)
return node_context_name(subpipeline_id, end_node_id)


def node_context_name(pipeline_context_name: str, node_id: str):
"""Defines the name used to reference a node context in MLMD."""
return f"{pipeline_context_name}.{node_id}"
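As a side note, a minimal self-contained sketch of the naming scheme the new helper composes; the "_end" suffix value and the "my_sub" id are assumptions for illustration (the real suffix comes from constants.PIPELINE_END_NODE_SUFFIX):

# Hypothetical illustration, not part of the diff.
PIPELINE_END_NODE_SUFFIX = "_end"  # assumed value of constants.PIPELINE_END_NODE_SUFFIX

def end_node_context_name_from_subpipeline_id(subpipeline_id: str) -> str:
    # Mirrors the helper above: subpipeline id + suffix gives the end node id,
    # and "<pipeline_context>.<node_id>" gives the MLMD context name.
    end_node_id = f"{subpipeline_id}{PIPELINE_END_NODE_SUFFIX}"
    return f"{subpipeline_id}.{end_node_id}"

print(end_node_context_name_from_subpipeline_id("my_sub"))  # -> my_sub.my_sub_end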
88 changes: 41 additions & 47 deletions tfx/orchestration/portable/partial_run_utils.py
@@ -108,7 +108,6 @@ def mark_pipeline(
"""
nodes = node_proto_view.get_view_for_all_in(pipeline)
_ensure_sync_pipeline(pipeline)
_ensure_no_subpipeline_nodes(nodes)
_ensure_no_partial_run_marks(nodes)
_ensure_not_full_run(from_nodes, to_nodes)
_ensure_no_missing_nodes(nodes, from_nodes, to_nodes)
@@ -228,28 +227,6 @@ def _ensure_sync_pipeline(pipeline: pipeline_pb2.Pipeline):
f'{pipeline_pb2.Pipeline.ExecutionMode.Name(pipeline.execution_mode)}')


def _ensure_no_subpipeline_nodes(
nodes: Sequence[node_proto_view.NodeProtoView],
):
"""Raises ValueError if the pipeline contains a sub-pipeline.
If the pipeline comes from the compiler, it should already be
flattened. This is just in case the IR proto was created in another way.
Args:
nodes: The nodes of the pipeline.
Raises:
ValueError: If the pipeline contains a sub-pipeline.
"""
for node in nodes:
if isinstance(node, node_proto_view.ComposablePipelineProtoView):
raise ValueError(
'Pipeline filtering not supported for pipelines with sub-pipelines. '
f'sub-pipeline found: {node.node_info.id}'
)


def _ensure_not_full_run(from_nodes: Optional[Collection[str]] = None,
to_nodes: Optional[Collection[str]] = None):
"""Raises ValueError if both from_nodes and to_nodes are falsy."""
@@ -536,6 +513,9 @@ def _reuse_pipeline_run_artifacts(
for node in node_proto_view.get_view_for_all_in(marked_pipeline)
if _should_attempt_to_reuse_artifact(node.execution_options)
]
logging.info(
'Reusing nodes: %s', [n.node_info.id for n in reuse_nodes]
)
with concurrent.futures.ThreadPoolExecutor() as executor:
futures = {}
for node in reuse_nodes:
@@ -544,7 +524,7 @@
node,
executor.submit(
artifact_recycler.reuse_node_outputs,
node_id=node_id,
node=node
),
)

@@ -628,6 +608,7 @@ def _get_base_pipeline_run_context(
if ctx.type_id == pipeline_run_type_id
}
else:
logging.info('No child contexts found. Falling back to previous logic.')
# The parent-child relationship between pipeline and pipeline run contexts
# is set up after the partial run feature is available to users. For
# existing pipelines, we need to fall back to the previous logic.
@@ -637,7 +618,6 @@
constants.PIPELINE_RUN_CONTEXT_TYPE_NAME
)
}

if base_run_id:
if base_run_id in pipeline_run_contexts:
return pipeline_run_contexts[base_run_id]
@@ -659,30 +639,43 @@
)
return sorted_run_contexts[-1]

def _get_node_context(self, node_id: str) -> metadata_store_pb2.Context:
node_context_name = compiler_utils.node_context_name(
self._pipeline_name, node_id
)
node_context = self._node_context_by_name.get(node_context_name)
def _get_node_context(
self, node: node_proto_view.NodeProtoView
) -> metadata_store_pb2.Context:
"""Returns node context for node."""
node_id = node.node_info.id
# Return the end node context if we want to reuse a subpipeline. We do this
# because nodes downstream of a subpipeline get their artifacts from the
# subpipeline's end node, so we reuse those artifacts.
if isinstance(node, node_proto_view.ComposablePipelineProtoView):
context_name = compiler_utils.end_node_context_name_from_subpipeline_id(
node_id
)
else:
context_name = compiler_utils.node_context_name(
self._pipeline_name, node_id
)
node_context = self._node_context_by_name.get(context_name)
if node_context is None:
raise LookupError(f'node context {node_context_name} not found in MLMD.')
raise LookupError(f'node context {context_name} not found in MLMD.')
return node_context

def _get_successful_executions(
self, node_id: str
self, node: node_proto_view.NodeProtoView
) -> List[metadata_store_pb2.Execution]:
"""Gets all successful Executions of a given node in a given pipeline run.
Args:
node_id: The node whose Executions to query.
node: The node whose Executions to query.
Returns:
All successful executions for that node at that run_id.
Raises:
LookupError: If no successful Execution was found.
"""
node_context = self._get_node_context(node_id)
node_context = self._get_node_context(node)
node_id = node.node_info.id
if not self._base_run_context:
raise LookupError(
f'No previous run is found for {node_id}. '
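For intuition, a hedged sketch of the context-selection rule that the new _get_node_context applies; the dict stand-in for the MLMD lookup and the "_end" suffix are assumptions:

# Simplified stand-in mirroring self._node_context_by_name from the diff.
def pick_context(node_context_by_name: dict, pipeline_name: str,
                 node_id: str, is_subpipeline: bool):
    if is_subpipeline:
        # Downstream nodes consume a sub-pipeline's outputs via its end node,
        # so artifact reuse targets the end-node context "<sub_id>.<sub_id>_end".
        name = f"{node_id}.{node_id}_end"
    else:
        name = f"{pipeline_name}.{node_id}"
    context = node_context_by_name.get(name)
    if context is None:
        raise LookupError(f"node context {name} not found in MLMD.")
    return context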
@@ -710,14 +703,14 @@ def _get_successful_executions(
def _cache_and_publish(
self,
existing_executions: List[metadata_store_pb2.Execution],
node_id: str,
node: node_proto_view.NodeProtoView,
):
"""Creates and publishes cache executions."""
if not existing_executions:
return

# Check if there are any previous attempts to cache and publish.
node_context = self._get_node_context(node_id)
node_context = self._get_node_context(node)
cached_execution_contexts = [
self._pipeline_context,
node_context,
@@ -728,11 +721,10 @@ def _cache_and_publish(
self._mlmd, contexts=[node_context, self._new_pipeline_run_context]
)
)

if not prev_cache_executions:
new_executions = []
new_cached_executions = []
for e in existing_executions:
new_executions.append(
new_cached_executions.append(
execution_lib.prepare_execution(
metadata_handle=self._mlmd,
execution_type=metadata_store_pb2.ExecutionType(id=e.type_id),
@@ -741,15 +733,17 @@
)
)
else:
new_executions = [
new_cached_executions = [
e
for e in prev_cache_executions
if e.last_known_state != metadata_store_pb2.Execution.CACHED
]

if not new_executions:
logging.info(
'New cached executions to be published: %s', new_cached_executions
)
if not new_cached_executions:
return
if len(new_executions) != len(existing_executions):
if len(new_cached_executions) != len(existing_executions):
raise RuntimeError(
'The number of new executions is not the same as the number of'
' existing executions.'
@@ -762,7 +756,7 @@ def put_parent_context(self):
execution_publish_utils.publish_cached_executions(
self._mlmd,
contexts=cached_execution_contexts,
executions=new_executions,
executions=new_cached_executions,
output_artifacts_maps=output_artifacts_maps,
)
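For orientation, a simplified, self-contained sketch of the cache-and-publish flow above; FakeExecution and the state strings are stand-ins for the MLMD types, not the real API:

from dataclasses import dataclass, field

@dataclass
class FakeExecution:  # stand-in for metadata_store_pb2.Execution
    type_id: int
    state: str = "UNKNOWN"
    outputs: dict = field(default_factory=dict)

def cache_and_publish(existing, prev_cache_attempts):
    # First attempt: mirror each successful base-run execution as a CACHED one
    # in the new pipeline run, pointing at the same outputs.
    if not prev_cache_attempts:
        return [FakeExecution(e.type_id, "CACHED", dict(e.outputs)) for e in existing]
    # Retry path: only finish previously created executions that never reached CACHED.
    return [e for e in prev_cache_attempts if e.state != "CACHED"]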

@@ -782,7 +776,7 @@ def put_parent_context(self):
child_id=self._new_pipeline_run_context.id,
)

def reuse_node_outputs(self, node_id: str):
def reuse_node_outputs(self, node: node_proto_view.NodeProtoView):
"""Makes the outputs of `node_id` available to new_pipeline_run_id."""
previous_executions = self._get_successful_executions(node_id)
self._cache_and_publish(previous_executions, node_id)
previous_executions = self._get_successful_executions(node)
self._cache_and_publish(previous_executions, node)
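Taken together, the effect of the change is that partial-run marking no longer rejects pipelines that contain sub-pipeline nodes. A hedged usage sketch; pipeline_ir and the node ids are hypothetical, while mark_pipeline, from_nodes, and to_nodes come from the diff:

from tfx.orchestration.portable import partial_run_utils

# pipeline_ir: a compiled pipeline_pb2.Pipeline that contains a sub-pipeline node.
partial_run_utils.mark_pipeline(
    pipeline_ir,
    from_nodes=["my_subpipeline"],  # starting a partial run at a sub-pipeline node
    to_nodes=["trainer"],           # nodes outside the range can be reused from a base run
)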
