Forwards all remaining keyword arguments to pass additional arguments to the base class.

This is a well-defined convention for Apache Airflow custom operators: https://airflow.apache.org/docs/apache-airflow/stable/howto/custom-operator.html
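
For reference, the convention from the linked guide is simply to accept `**kwargs` in the custom operator's constructor and forward it to `BaseOperator.__init__`. A minimal sketch (illustrative only, not part of this commit; the operator name and field are made up):

from airflow.models.baseoperator import BaseOperator


class HelloOperator(BaseOperator):
  """Toy custom operator that forwards unrecognized kwargs to the base class."""

  def __init__(self, name: str, **kwargs):
    # Anything not handled here (retries, execution_timeout, etc.) is passed
    # through to BaseOperator instead of being spelled out explicitly.
    super().__init__(**kwargs)
    self._name = name

  def execute(self, context):
    message = f'Hello {self._name}'
    print(message)
    return message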

Also, apache-airflow 2.3.0 requires a running scheduler process to handle the `unpause` and `trigger` commands, so the tests were changed to start the `scheduler` first. I refactored `AirflowSubprocess` into `AirflowScheduler` because `AirflowSubprocess` was only used for the scheduler process.
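
In other words, the updated tests boil down to roughly this ordering (a sketch only; `my_dag` stands in for the real DAG id and the `airflow` CLI is assumed to be on PATH):

import subprocess

scheduler = subprocess.Popen(['airflow', 'scheduler'])
try:
  # Per the description above, apache-airflow 2.3.0 needs a scheduler running
  # for `unpause`/`trigger` to actually take effect.
  subprocess.run(['airflow', 'dags', 'unpause', 'my_dag'], check=True)
  subprocess.run(['airflow', 'dags', 'trigger', 'my_dag'], check=True)
finally:
  scheduler.terminate()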

PiperOrigin-RevId: 447368998
jiyongjung authored and tfx-copybara committed May 9, 2022
1 parent c0e6668 commit f986f6b
Showing 6 changed files with 56 additions and 41 deletions.
6 changes: 4 additions & 2 deletions RELEASE.md
@@ -21,13 +21,15 @@
packaged) fails.
* Fixed `ElwcBigQueryExampleGen` data serialization error that was causing an
assertion failure on Beam.
* Temporarily capped `apache-airflow` version to 2.2.x to avoid dependency
conflict. We will rollback this change once `kfp` releases a new version.
* Added dark mode styling support for InteractiveContext notebook formatters.
* (Python 3.9+) Supports `list` and `dict` in type definition of execution
properties.
* Populate Artifact proto `name` field when name is set on the Artifact python
object.
* Temporarily capped `apache-airflow` version to 2.2.x to avoid dependency
conflict. We will rollback this change once `kfp` releases a new version.
* Fixed a compatibility issue with apache-airflow 2.3.0 that failed with
"unexpected keyword argument 'default_args'".

## Dependency Updates

@@ -30,22 +30,6 @@
from tfx.utils import test_case_utils


class AirflowSubprocess:
  """Launch an Airflow command."""

  def __init__(self, airflow_args):
    self._args = ['airflow'] + airflow_args
    self._sub_process = None

  def __enter__(self):
    self._sub_process = subprocess.Popen(self._args)
    return self

  def __exit__(self, exception_type, exception_value, traceback):  # pylint: disable=unused-argument
    if self._sub_process:
      self._sub_process.terminate()


# Number of seconds between polling pending task states.
_TASK_POLLING_INTERVAL_SEC = 10
# Maximum duration to allow no task state change.
@@ -188,26 +172,25 @@ def setUp(self):

    # Initialize database.
    subprocess.run(['airflow', 'db', 'init'], check=True)
    subprocess.run(['airflow', 'dags', 'unpause', self._dag_id], check=True)

  def testSimplePipeline(self):
    subprocess.run([
        'airflow',
        'dags',
        'trigger',
        self._dag_id,
        '-r',
        self._run_id,
        '-e',
        self._execution_date,
    ],
                   check=True)
    absl.logging.info('Dag triggered: %s', self._dag_id)
    # We will use subprocess to start the DAG instead of webserver, so only
    # need to start a scheduler on the background.
    # Airflow scheduler should be launched after triggering the dag to mitigate
    # a possible race condition between trigger_dag and scheduler.
    with AirflowSubprocess(['scheduler']):
    with airflow_test_utils.AirflowScheduler():
      subprocess.run(['airflow', 'dags', 'unpause', self._dag_id], check=True)
      subprocess.run([
          'airflow',
          'dags',
          'trigger',
          self._dag_id,
          '-r',
          self._run_id,
          '-e',
          self._execution_date,
      ],
                     check=True)
      absl.logging.info('Dag triggered: %s', self._dag_id)

      pending_tasks = set(self._all_tasks)
      attempts = int(
          _MAX_TASK_STATE_CHANGE_SEC / _TASK_POLLING_INTERVAL_SEC) + 1
7 changes: 5 additions & 2 deletions tfx/orchestration/airflow/airflow_component.py
@@ -89,7 +89,8 @@ def __init__(self, *, parent_dag: models.DAG, component: base_node.BaseNode,
               metadata_connection_config: metadata_store_pb2.ConnectionConfig,
               beam_pipeline_args: List[str],
               additional_pipeline_args: Dict[str, Any],
               component_config: base_component_config.BaseComponentConfig):
               component_config: base_component_config.BaseComponentConfig,
               **kwargs):
    """Constructs an Airflow implementation of TFX component.
    Args:
@@ -105,6 +106,7 @@ def __init__(self, *, parent_dag: models.DAG, component: base_node.BaseNode,
      beam_pipeline_args: Pipeline arguments for Beam powered Components.
      additional_pipeline_args: Additional pipeline args.
      component_config: Component config to launch the component.
      **kwargs: Additional params passed to the base class, PythonOperator.
    """
    # Prepare parameters to create TFX worker.
    driver_args = data_types.DriverArgs(enable_cache=enable_cache)
@@ -129,4 +131,5 @@ def __init__(self, *, parent_dag: models.DAG, component: base_node.BaseNode,
        # op_kwargs is a templated field for PythonOperator, which means Airflow
        # will inspect the dictionary and resolve any templated fields.
        op_kwargs={'exec_properties': exec_properties},
        dag=parent_dag)
        dag=parent_dag,
        **kwargs)
6 changes: 4 additions & 2 deletions tfx/orchestration/airflow/airflow_component_test.py
@@ -110,14 +110,16 @@ def testAirflowComponent(self, mock_python_operator_init):
        metadata_connection_config=self._metadata_connection_config,
        beam_pipeline_args=[],
        additional_pipeline_args={},
        component_config=None)
        component_config=None,
        default_args={})

    mock_python_operator_init.assert_called_once_with(
        task_id=self._component.id,
        provide_context=True,
        python_callable=mock.ANY,
        dag=self._parent_dag,
        op_kwargs={'exec_properties': {}})
        op_kwargs={'exec_properties': {}},
        default_args={})

    python_callable = mock_python_operator_init.call_args_list[0][1][
        'python_callable']
23 changes: 23 additions & 0 deletions tfx/orchestration/airflow/test_utils.py
@@ -106,3 +106,26 @@ def delete_mysql_container(container_name: str):
  container = client.containers.get(container_name)
  container.remove(force=True)
  client.close()


class AirflowScheduler:
  """Launches an Airflow scheduler for the duration of the context."""

  def __init__(self):
    self._args = ['airflow', 'scheduler']
    self._sub_process = None

  def __enter__(self):
    self._sub_process = subprocess.Popen(self._args)
    while True:
      time.sleep(10)
      check_result = subprocess.run(  # pylint: disable=subprocess-run-check
          ['airflow', 'jobs', 'check', '--job-type', 'SchedulerJob'])
      if check_result.returncode == 0:
        time.sleep(10)  # Wait for a bit more until dag processing is completed.
        break
    return self

  def __exit__(self, exception_type, exception_value, traceback):  # pylint: disable=unused-argument
    if self._sub_process:
      self._sub_process.terminate()
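
A minimal usage sketch of the new helper (mirroring the e2e test change below; `my_dag` is a placeholder and the import path assumes the file location shown above):

import subprocess

from tfx.orchestration.airflow import test_utils as airflow_test_utils

with airflow_test_utils.AirflowScheduler():
  # The scheduler has already passed `airflow jobs check` at this point, so DAG
  # commands issued here will actually be picked up.
  subprocess.run(['airflow', 'dags', 'unpause', 'my_dag'], check=True)
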
8 changes: 5 additions & 3 deletions tfx/tools/cli/e2e/cli_airflow_e2e_test.py
@@ -282,9 +282,11 @@ def _valid_run_and_check(self, pipeline_name):

    self._reload_airflow_dags()

    result = self.runner.invoke(cli_group, [
        'run', 'create', '--engine', 'airflow', '--pipeline_name', pipeline_name
    ])
    with airflow_test_utils.AirflowScheduler():
      result = self.runner.invoke(cli_group, [
          'run', 'create', '--engine', 'airflow', '--pipeline_name',
          pipeline_name
      ])

    self.assertIn('Creating a run for pipeline: {}'.format(pipeline_name),
                  result.output)
