From e100ea4b744c6c266d511b7c4d893fef22e4aa72 Mon Sep 17 00:00:00 2001 From: Matt Usifer Date: Tue, 18 Feb 2025 17:50:32 -0500 Subject: [PATCH 1/2] Add dag_ids and exclude_dag_ids args to db_clean --- airflow/cli/cli_config.py | 4 + .../cli/commands/local_commands/db_command.py | 2 + airflow/utils/db_cleanup.py | 90 +++++++++--- tests/utils/test_db_cleanup.py | 134 +++++++++++++++++- 4 files changed, 210 insertions(+), 20 deletions(-) diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py index 12664960b2495..0e777e3100bec 100644 --- a/airflow/cli/cli_config.py +++ b/airflow/cli/cli_config.py @@ -484,6 +484,8 @@ def string_lower_type(val): type=positive_int(allow_zero=False), help="Wait time between retries in seconds", ) +ARG_DAG_IDS = Arg(("dag_ids",), help="The ids of the dags to clean up") +ARG_EXCLUDE_DAG_IDS = Arg(("exclude_dag_ids",), help="The ids of the dags to exclude from clean up") # pool ARG_POOL_NAME = Arg(("pool",), metavar="NAME", help="Pool name") @@ -1528,6 +1530,8 @@ class GroupCommand(NamedTuple): ARG_VERBOSE, ARG_YES, ARG_DB_SKIP_ARCHIVE, + ARG_DAG_IDS, + ARG_EXCLUDE_DAG_IDS, ), ), ActionCommand( diff --git a/airflow/cli/commands/local_commands/db_command.py b/airflow/cli/commands/local_commands/db_command.py index 5a4f70cc60472..60c2b6121bf0b 100644 --- a/airflow/cli/commands/local_commands/db_command.py +++ b/airflow/cli/commands/local_commands/db_command.py @@ -283,6 +283,8 @@ def cleanup_tables(args): verbose=args.verbose, confirm=not args.yes, skip_archive=args.skip_archive, + dag_ids=args.dag_ids, + exclude_dag_ids=args.exclude_dag_ids, ) diff --git a/airflow/utils/db_cleanup.py b/airflow/utils/db_cleanup.py index 8c7b95a6e8940..8255130d40cf7 100644 --- a/airflow/utils/db_cleanup.py +++ b/airflow/utils/db_cleanup.py @@ -77,15 +77,26 @@ class _TableConfig: table_name: str recency_column_name: str extra_columns: list[str] | None = None + dag_id_column_name: str | None = None keep_last: bool = False keep_last_filters: Any | None = None keep_last_group_by: Any | None = None def __post_init__(self): self.recency_column = column(self.recency_column_name) - self.orm_model: Base = table( - self.table_name, *[column(x) for x in self.extra_columns or []], self.recency_column - ) + if self.dag_id_column_name is None: + self.dag_id_column = None + self.orm_model: Base = table( + self.table_name, *[column(x) for x in self.extra_columns or []], self.recency_column + ) + else: + self.dag_id_column = column(self.dag_id_column_name) + self.orm_model: Base = table( + self.table_name, + *[column(x) for x in self.extra_columns or []], + self.dag_id_column, + self.recency_column, + ) def __lt__(self, other): return self.table_name < other.table_name @@ -95,6 +106,7 @@ def readable_config(self): return { "table": self.orm_model.name, "recency_column": str(self.recency_column), + "dag_id_column": str(self.dag_id_column), "keep_last": self.keep_last, "keep_last_filters": [str(x) for x in self.keep_last_filters] if self.keep_last_filters else None, "keep_last_group_by": str(self.keep_last_group_by), @@ -102,31 +114,36 @@ def readable_config(self): config_list: list[_TableConfig] = [ - _TableConfig(table_name="job", recency_column_name="latest_heartbeat"), - _TableConfig(table_name="dag", recency_column_name="last_parsed_time"), + _TableConfig(table_name="job", recency_column_name="latest_heartbeat", dag_id_column_name="dag_id"), + _TableConfig(table_name="dag", recency_column_name="last_parsed_time", dag_id_column_name="dag_id"), _TableConfig( table_name="dag_run", 
recency_column_name="start_date", + dag_id_column_name="dag_id", extra_columns=["dag_id", "external_trigger"], keep_last=True, keep_last_filters=[column("external_trigger") == false()], keep_last_group_by=["dag_id"], ), - _TableConfig(table_name="asset_event", recency_column_name="timestamp"), + _TableConfig( + table_name="asset_event", recency_column_name="timestamp", dag_id_column_name="source_dag_id" + ), _TableConfig(table_name="import_error", recency_column_name="timestamp"), - _TableConfig(table_name="log", recency_column_name="dttm"), - _TableConfig(table_name="sla_miss", recency_column_name="timestamp"), - _TableConfig(table_name="task_instance", recency_column_name="start_date"), - _TableConfig(table_name="task_instance_history", recency_column_name="start_date"), - _TableConfig(table_name="task_reschedule", recency_column_name="start_date"), - _TableConfig(table_name="xcom", recency_column_name="timestamp"), - _TableConfig(table_name="_xcom_archive", recency_column_name="timestamp"), + _TableConfig(table_name="log", recency_column_name="dttm", dag_id_column_name="dag_id"), + _TableConfig(table_name="sla_miss", recency_column_name="timestamp", dag_id_column_name="dag_id"), + _TableConfig(table_name="task_instance", recency_column_name="start_date", dag_id_column_name="dag_id"), + _TableConfig( + table_name="task_instance_history", recency_column_name="start_date", dag_id_column_name="dag_id" + ), + _TableConfig(table_name="task_reschedule", recency_column_name="start_date", dag_id_column_name="dag_id"), + _TableConfig(table_name="xcom", recency_column_name="timestamp", dag_id_column_name="dag_id"), + _TableConfig(table_name="_xcom_archive", recency_column_name="timestamp", dag_id_column_name="dag_id"), _TableConfig(table_name="callback_request", recency_column_name="created_at"), _TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"), _TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"), _TableConfig(table_name="trigger", recency_column_name="created_date"), - _TableConfig(table_name="dag_version", recency_column_name="created_at"), - _TableConfig(table_name="deadline", recency_column_name="deadline"), + _TableConfig(table_name="dag_version", recency_column_name="created_at", dag_id_column_name="dag_id"), + _TableConfig(table_name="deadline", recency_column_name="deadline", dag_id_column_name="dag_id"), ] if conf.get("webserver", "session_backend") == "database": @@ -248,6 +265,9 @@ def _build_query( keep_last_group_by, clean_before_timestamp: DateTime, session: Session, + dag_id_column=None, + dag_ids: list[str] | None = None, + exclude_dag_ids: list[str] | None = None, **kwargs, ) -> Query: base_table_alias = "base" @@ -255,6 +275,18 @@ def _build_query( query = session.query(base_table).with_entities(text(f"{base_table_alias}.*")) base_table_recency_col = base_table.c[recency_column.name] conditions = [base_table_recency_col < clean_before_timestamp] + + if dag_ids or exclude_dag_ids: + if dag_id_column is None: + raise ValueError("Must provide a dag_id_column along with dag_ids and exclude_dag_ids") + + base_table_dag_id_col = base_table.c[dag_id_column.name] + + if dag_ids: + conditions.append(base_table_dag_id_col.in_(dag_ids)) + if exclude_dag_ids: + conditions.append(base_table_dag_id_col.not_in(exclude_dag_ids)) + if keep_last: max_date_col_name = "max_date_per_group" group_by_columns = [column(x) for x in keep_last_group_by] @@ -285,6 +317,9 @@ def _cleanup_table( keep_last_filters, keep_last_group_by, 
     clean_before_timestamp: DateTime,
+    dag_id_column=None,
+    dag_ids=None,
+    exclude_dag_ids=None,
     dry_run: bool = True,
     verbose: bool = False,
     skip_archive: bool = False,
@@ -297,6 +332,9 @@ def _cleanup_table(
     query = _build_query(
         orm_model=orm_model,
         recency_column=recency_column,
+        dag_id_column=dag_id_column,
+        dag_ids=dag_ids,
+        exclude_dag_ids=exclude_dag_ids,
         keep_last=keep_last,
         keep_last_filters=keep_last_filters,
         keep_last_group_by=keep_last_group_by,
@@ -313,10 +351,14 @@ def _cleanup_table(
         session.commit()
 
 
-def _confirm_delete(*, date: DateTime, tables: list[str]) -> None:
+def _confirm_delete(
+    *, date: DateTime, tables: list[str], dag_ids: list[str] | None, exclude_dag_ids: list[str] | None
+) -> None:
     for_tables = f" for tables {tables!r}" if tables else ""
+    for_dags = f" for the following dags: {dag_ids!r}" if dag_ids else ""
+    excluding_dags = f" excluding the following dags: {exclude_dag_ids!r}" if exclude_dag_ids else ""
     question = (
-        f"You have requested that we purge all data prior to {date}{for_tables}.\n"
+        f"You have requested that we purge all data prior to {date}{for_tables}{for_dags}{excluding_dags}.\n"
         f"This is irreversible. Consider backing up the tables first and / or doing a dry run "
         f"with option --dry-run.\n"
         f"Enter 'delete rows' (without quotes) to proceed."
@@ -410,6 +452,8 @@ def run_cleanup(
     *,
     clean_before_timestamp: DateTime,
     table_names: list[str] | None = None,
+    dag_ids: list[str] | None = None,
+    exclude_dag_ids: list[str] | None = None,
     dry_run: bool = False,
     verbose: bool = False,
     confirm: bool = True,
@@ -429,6 +473,9 @@ def run_cleanup(
     :param clean_before_timestamp: The timestamp before which data should be purged
     :param table_names: Optional. List of table names to perform maintenance on. If list not provided,
         will perform maintenance on all tables.
+    :param dag_ids: Optional. List of dag ids to perform maintenance on. If list not provided,
+        will perform maintenance on all dags.
+    :param exclude_dag_ids: Optional. List of dag ids to exclude from maintenance.
     :param dry_run: If true, print rows meeting deletion criteria
     :param verbose: If true, may provide more detailed output.
     :param confirm: Require user input to confirm before processing deletions.
@@ -445,13 +492,20 @@ def run_cleanup(
         )
     _print_config(configs=effective_config_dict)
     if not dry_run and confirm:
-        _confirm_delete(date=clean_before_timestamp, tables=sorted(effective_table_names))
+        _confirm_delete(
+            date=clean_before_timestamp,
+            tables=sorted(effective_table_names),
+            dag_ids=dag_ids,
+            exclude_dag_ids=exclude_dag_ids,
+        )
     existing_tables = reflect_tables(tables=None, session=session).tables
     for table_name, table_config in effective_config_dict.items():
         if table_name in existing_tables:
             with _suppress_with_logging(table_name, session):
                 _cleanup_table(
                     clean_before_timestamp=clean_before_timestamp,
+                    dag_ids=dag_ids,
+                    exclude_dag_ids=exclude_dag_ids,
                     dry_run=dry_run,
                     verbose=verbose,
                     **table_config.__dict__,
diff --git a/tests/utils/test_db_cleanup.py b/tests/utils/test_db_cleanup.py
index 63056337a5219..43f69d872c08b 100644
--- a/tests/utils/test_db_cleanup.py
+++ b/tests/utils/test_db_cleanup.py
@@ -275,6 +275,136 @@ def test__cleanup_table(self, table_name, date_add_kwargs, expected_to_delete, e
         else:
             raise Exception("unexpected")
 
+    @pytest.mark.parametrize(
+        "table_name, dag_ids, date_add_kwargs, expected_to_delete",
+        [
+            pytest.param(
+                "task_instance",
+                dict(
+                    dag_to_delete1=f"dag_to_delete_{uuid4()}",
+                    dag_to_delete2=f"dag_to_delete_{uuid4()}",
+                    dag_to_not_delete=f"dag_to_not_delete_{uuid4()}",
+                ),
+                dict(days=4),
+                8,  # should only delete 8 TIs (4 from each "dag_to_delete" DAG above)
+                id="only_delete_some_dag_ids",
+            ),
+            pytest.param(
+                "task_instance",
+                dict(
+                    dag_to_implicitly_delete1=f"dag_to_implicitly_delete_{uuid4()}",
+                    dag_to_implicitly_delete2=f"dag_to_implicitly_delete_{uuid4()}",
+                    dag_to_implicitly_delete3=f"dag_to_implicitly_delete_{uuid4()}",
+                ),
+                dict(days=20),
+                15,  # All DAGs should be deleted since none were passed into the 'dag_ids' param
+                id="delete_all_dag_ids",
+            ),
+            pytest.param(
+                "dag_run",
+                dict(
+                    dag_to_delete1=f"dag_to_delete_{uuid4()}",
+                    dag_to_delete2=f"dag_to_delete_{uuid4()}",
+                    dag_to_not_delete=f"dag_to_not_delete_{uuid4()}",
+                ),
+                dict(days=30),
+                8,  # delete 8 DagRuns, 4 from each dag_to_delete DAG (not 5, dag_run has keep_last=True)
+                id="delete_from_dag_run_table",
+            ),
+        ],
+    )
+    def test__cleanup_dag_ids(self, table_name, dag_ids, date_add_kwargs, expected_to_delete):
+        """
+        Verify that _cleanup_table deletes the rows it should when the dag_ids parameter is supplied.
+        Only rows belonging to DAGs in the dag_ids list are deleted; if no dag_ids are provided,
+        rows for all DAGs are deleted.
+ """ + base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("America/New_York")) + num_tis = 5 + + for _, dag_id in dag_ids.items(): + create_tis(base_date=base_date, num_tis=num_tis, dag_id=dag_id) + + with create_session() as session: + clean_before_date = base_date.add(**date_add_kwargs) + _cleanup_table( + **config_dict[table_name].__dict__, + clean_before_timestamp=clean_before_date, + dry_run=False, + session=session, + dag_ids=[dag_id for dag_id in dag_ids.values() if "dag_to_delete" in dag_id], + ) + + expected_remaining = (num_tis * len(dag_ids)) - expected_to_delete + + model = config_dict[table_name].orm_model + assert len(session.query(model).all()) == expected_remaining + + @pytest.mark.parametrize( + "table_name, dag_ids, date_add_kwargs, expected_to_delete", + [ + pytest.param( + "task_instance", + dict( + dag_to_exclude1=f"dag_to_exclude_{uuid4()}", + dag_to_exclude2=f"dag_to_exclude_{uuid4()}", + dag_to_not_exclude=f"dag_to_not_exclude_{uuid4()}", + ), + dict(days=4), + 4, # should only delete 4 TI's (all from the dag_to_not_exclude DAG above) + id="only_exclude_some_dag_ids", + ), + pytest.param( + "task_instance", + dict( + dag_to_implicitly_delete1=f"dag_to_implicitly_delete_{uuid4()}", + dag_to_implicitly_delete2=f"dag_to_implicitly_delete_{uuid4()}", + dag_to_implicitly_delete3=f"dag_to_implicitly_delete_{uuid4()}", + ), + dict(days=20), + 15, # All DAGs should be deleted since none were excluded + id="delete_all_dag_ids", + ), + pytest.param( + "dag_run", + dict( + dag_to_exclude1=f"dag_to_exclude_{uuid4()}", + dag_to_exclude2=f"dag_to_exclude_{uuid4()}", + dag_to_not_exclude=f"dag_to_not_exclude_{uuid4()}", + ), + dict(days=20), + 4, # delete 4 DagRuns (not 5, dag_run has keep_last=True) + id="delete_from_dag_run_table", + ), + ], + ) + def test__cleanup_exclude_dag_ids(self, table_name, dag_ids, date_add_kwargs, expected_to_delete): + """ + Verify that _cleanup_table actually deletes the rows it should with the exclude_dag_ids parameter. + The _cleanup_table should delete all the dags that are not in the exclude_dag_ids list and + if there are none, delete all DAGs. 
+        """
+        base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("America/New_York"))
+        num_tis = 5  # each test case above has 3 DAGs with 5 TIs each
+
+        for _, dag_id in dag_ids.items():
+            create_tis(base_date=base_date, num_tis=num_tis, dag_id=dag_id)
+
+        with create_session() as session:
+            clean_before_date = base_date.add(**date_add_kwargs)
+            _cleanup_table(
+                **config_dict[table_name].__dict__,
+                clean_before_timestamp=clean_before_date,
+                dry_run=False,
+                session=session,
+                exclude_dag_ids=[dag_id for dag_id in dag_ids.values() if "dag_to_exclude" in dag_id],
+            )
+
+            expected_remaining = (num_tis * len(dag_ids)) - expected_to_delete
+
+            model = config_dict[table_name].orm_model
+            assert len(session.query(model).all()) == expected_remaining
+
     @pytest.mark.parametrize(
         "skip_archive, expected_archives",
         [pytest.param(True, 1, id="skip_archive"), pytest.param(False, 2, id="do_archive")],
@@ -549,9 +679,9 @@ def test_drop_archived_tables(self, mock_input, confirm_mock, inspect_mock, capl
         confirm_mock.assert_not_called()
 
 
-def create_tis(base_date, num_tis, external_trigger=False):
+def create_tis(base_date, num_tis, dag_id=None, external_trigger=False):
     with create_session() as session:
-        dag = DagModel(dag_id=f"test-dag_{uuid4()}")
+        dag = DagModel(dag_id=dag_id) if dag_id else DagModel(dag_id=f"test-dag_{uuid4()}")
         session.add(dag)
         for num in range(num_tis):
             start_date = base_date.add(days=num)

From b8f93e965b6dc8a2bb4253d62845cd9da6527bdd Mon Sep 17 00:00:00 2001
From: Matt Usifer
Date: Tue, 18 Feb 2025 18:02:15 -0500
Subject: [PATCH 2/2] Fix args

---
 airflow/cli/cli_config.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/airflow/cli/cli_config.py b/airflow/cli/cli_config.py
index 0e777e3100bec..c1d1dea24170f 100644
--- a/airflow/cli/cli_config.py
+++ b/airflow/cli/cli_config.py
@@ -484,8 +484,8 @@ def string_lower_type(val):
     type=positive_int(allow_zero=False),
     help="Wait time between retries in seconds",
 )
-ARG_DAG_IDS = Arg(("dag_ids",), help="The ids of the dags to clean up")
-ARG_EXCLUDE_DAG_IDS = Arg(("exclude_dag_ids",), help="The ids of the dags to exclude from clean up")
+ARG_DAG_IDS = Arg(("--dag-ids",), nargs="+", help="The ids of the dags to clean up")
+ARG_EXCLUDE_DAG_IDS = Arg(("--exclude-dag-ids",), nargs="+", help="The ids of the dags to exclude from clean up")
 # pool
 ARG_POOL_NAME = Arg(("pool",), metavar="NAME", help="Pool name")
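
Usage sketch (illustrative only, not part of either patch): with both patches applied, the new filters can be driven from Python or the CLI roughly as shown below. The DAG IDs, timestamp, and table names are made up, and the CLI line assumes the --dag-ids / --exclude-dag-ids flags accept one or more space-separated values.

    # Purge rows older than 30 days, but only for two specific DAGs.
    import pendulum

    from airflow.utils.db_cleanup import run_cleanup

    run_cleanup(
        clean_before_timestamp=pendulum.now("UTC").subtract(days=30),
        table_names=["task_instance", "dag_run"],  # optional; defaults to all configured tables
        dag_ids=["example_dag_a", "example_dag_b"],  # purge rows only for these DAGs
        exclude_dag_ids=None,  # or leave dag_ids unset and exclude specific DAGs instead
        dry_run=True,  # print matching rows without deleting anything
        confirm=False,
    )

    # Roughly equivalent CLI invocation (dry run):
    #   airflow db clean --clean-before-timestamp 2025-01-01 --dag-ids example_dag_a example_dag_b --dry-run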