Add new arguments to db_clean to explicitly include or exclude DAGs #46876

Draft · wants to merge 2 commits into main
4 changes: 4 additions & 0 deletions airflow/cli/cli_config.py
@@ -484,6 +484,8 @@ def string_lower_type(val):
type=positive_int(allow_zero=False),
help="Wait time between retries in seconds",
)
ARG_DAG_IDS = Arg(("--dag-ids",), help="The ids of the dags to clean up")
ARG_EXCLUDE_DAG_IDS = Arg(("--exclude-dag-ids",), help="The ids of the dags to exclude from clean up")

# pool
ARG_POOL_NAME = Arg(("pool",), metavar="NAME", help="Pool name")
@@ -1528,6 +1530,8 @@ class GroupCommand(NamedTuple):
ARG_VERBOSE,
ARG_YES,
ARG_DB_SKIP_ARCHIVE,
ARG_DAG_IDS,
ARG_EXCLUDE_DAG_IDS,
),
),
ActionCommand(
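As drafted, the two new flags are plain `Arg` entries with no `type=` or `nargs=`, so the parser will hand the command a single string per flag. A rough standalone sketch of that behavior (this is not Airflow's actual parser wiring, and the comma-splitting convention is an assumption, not something this diff defines):

import argparse

# Standalone approximation of how the two new Arg entries map onto argparse.
parser = argparse.ArgumentParser(prog="airflow db clean")
parser.add_argument("--dag-ids", help="The ids of the dags to clean up")
parser.add_argument("--exclude-dag-ids", help="The ids of the dags to exclude from clean up")

args = parser.parse_args(["--dag-ids", "dag_a,dag_b"])
# run_cleanup (diffed below) annotates dag_ids as list[str] | None, so the raw
# string would need splitting somewhere; this comma convention is hypothetical.
dag_ids = args.dag_ids.split(",") if args.dag_ids else None
print(dag_ids)  # ['dag_a', 'dag_b']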
2 changes: 2 additions & 0 deletions airflow/cli/commands/local_commands/db_command.py
@@ -283,6 +283,8 @@ def cleanup_tables(args):
verbose=args.verbose,
confirm=not args.yes,
skip_archive=args.skip_archive,
dag_ids=args.dag_ids,
exclude_dag_ids=args.exclude_dag_ids,
)


90 changes: 72 additions & 18 deletions airflow/utils/db_cleanup.py
@@ -77,15 +77,26 @@ class _TableConfig:
table_name: str
recency_column_name: str
extra_columns: list[str] | None = None
dag_id_column_name: str | None = None
keep_last: bool = False
keep_last_filters: Any | None = None
keep_last_group_by: Any | None = None

def __post_init__(self):
self.recency_column = column(self.recency_column_name)
self.orm_model: Base = table(
self.table_name, *[column(x) for x in self.extra_columns or []], self.recency_column
)
if self.dag_id_column_name is None:
self.dag_id_column = None
self.orm_model: Base = table(
self.table_name, *[column(x) for x in self.extra_columns or []], self.recency_column
)
else:
self.dag_id_column = column(self.dag_id_column_name)
self.orm_model: Base = table(
self.table_name,
*[column(x) for x in self.extra_columns or []],
self.dag_id_column,
self.recency_column,
)

def __lt__(self, other):
return self.table_name < other.table_name
@@ -95,38 +106,44 @@ def readable_config(self):
return {
"table": self.orm_model.name,
"recency_column": str(self.recency_column),
"dag_id_column": str(self.dag_id_column),
"keep_last": self.keep_last,
"keep_last_filters": [str(x) for x in self.keep_last_filters] if self.keep_last_filters else None,
"keep_last_group_by": str(self.keep_last_group_by),
}


config_list: list[_TableConfig] = [
_TableConfig(table_name="job", recency_column_name="latest_heartbeat"),
_TableConfig(table_name="dag", recency_column_name="last_parsed_time"),
_TableConfig(table_name="job", recency_column_name="latest_heartbeat", dag_id_column_name="dag_id"),
_TableConfig(table_name="dag", recency_column_name="last_parsed_time", dag_id_column_name="dag_id"),
_TableConfig(
table_name="dag_run",
recency_column_name="start_date",
dag_id_column_name="dag_id",
extra_columns=["dag_id", "external_trigger"],
keep_last=True,
keep_last_filters=[column("external_trigger") == false()],
keep_last_group_by=["dag_id"],
),
_TableConfig(table_name="asset_event", recency_column_name="timestamp"),
_TableConfig(
table_name="asset_event", recency_column_name="timestamp", dag_id_column_name="source_dag_id"
),
_TableConfig(table_name="import_error", recency_column_name="timestamp"),
_TableConfig(table_name="log", recency_column_name="dttm"),
_TableConfig(table_name="sla_miss", recency_column_name="timestamp"),
_TableConfig(table_name="task_instance", recency_column_name="start_date"),
_TableConfig(table_name="task_instance_history", recency_column_name="start_date"),
_TableConfig(table_name="task_reschedule", recency_column_name="start_date"),
_TableConfig(table_name="xcom", recency_column_name="timestamp"),
_TableConfig(table_name="_xcom_archive", recency_column_name="timestamp"),
_TableConfig(table_name="log", recency_column_name="dttm", dag_id_column_name="dag_id"),
_TableConfig(table_name="sla_miss", recency_column_name="timestamp", dag_id_column_name="dag_id"),
_TableConfig(table_name="task_instance", recency_column_name="start_date", dag_id_column_name="dag_id"),
_TableConfig(
table_name="task_instance_history", recency_column_name="start_date", dag_id_column_name="dag_id"
),
_TableConfig(table_name="task_reschedule", recency_column_name="start_date", dag_id_column_name="dag_id"),
_TableConfig(table_name="xcom", recency_column_name="timestamp", dag_id_column_name="dag_id"),
_TableConfig(table_name="_xcom_archive", recency_column_name="timestamp", dag_id_column_name="dag_id"),
_TableConfig(table_name="callback_request", recency_column_name="created_at"),
_TableConfig(table_name="celery_taskmeta", recency_column_name="date_done"),
_TableConfig(table_name="celery_tasksetmeta", recency_column_name="date_done"),
_TableConfig(table_name="trigger", recency_column_name="created_date"),
_TableConfig(table_name="dag_version", recency_column_name="created_at"),
_TableConfig(table_name="deadline", recency_column_name="deadline"),
_TableConfig(table_name="dag_version", recency_column_name="created_at", dag_id_column_name="dag_id"),
_TableConfig(table_name="deadline", recency_column_name="deadline", dag_id_column_name="dag_id"),
]

if conf.get("webserver", "session_backend") == "database":
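
For reference, a minimal standalone SQLAlchemy sketch (not the actual Airflow models) of the lightweight selectable that `__post_init__` now builds when `dag_id_column_name` is set, which is what lets `_build_query` address the DAG id column through `.c`:

from sqlalchemy import column, table

# Mirrors the table() call above for e.g. the "log" entry.
log = table("log", column("dag_id"), column("dttm"))

# _build_query reaches both columns via the .c collection:
print(log.c["dag_id"])  # log.dag_id
print(log.c["dttm"])    # log.dttm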
@@ -248,13 +265,28 @@ def _build_query(
keep_last_group_by,
clean_before_timestamp: DateTime,
session: Session,
dag_id_column=None,
dag_ids: list[str] | None = None,
exclude_dag_ids: list[str] | None = None,
**kwargs,
) -> Query:
base_table_alias = "base"
base_table = aliased(orm_model, name=base_table_alias)
query = session.query(base_table).with_entities(text(f"{base_table_alias}.*"))
base_table_recency_col = base_table.c[recency_column.name]
conditions = [base_table_recency_col < clean_before_timestamp]

if dag_ids or exclude_dag_ids:
if dag_id_column is None:
raise ValueError("Must provide a dag_id_column when filtering by dag_ids or exclude_dag_ids")

base_table_dag_id_col = base_table.c[dag_id_column.name]

if dag_ids:
conditions.append(base_table_dag_id_col.in_(dag_ids))
if exclude_dag_ids:
conditions.append(base_table_dag_id_col.not_in(exclude_dag_ids))

if keep_last:
max_date_col_name = "max_date_per_group"
group_by_columns = [column(x) for x in keep_last_group_by]
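
To make the new filtering concrete, here is a standalone sketch of the WHERE clause these conditions compile to; the table and values are hypothetical, but the `in_` / `not_in` calls mirror the code above:

from sqlalchemy import column, select, table

ti = table("task_instance", column("dag_id"), column("start_date"))
conditions = [
    ti.c["start_date"] < "2024-01-01",       # recency cutoff
    ti.c["dag_id"].in_(["dag_a", "dag_b"]),  # dag_ids filter
    ti.c["dag_id"].not_in(["dag_c"]),        # exclude_dag_ids filter
]
# Roughly: SELECT ... FROM task_instance WHERE start_date < :param
#          AND dag_id IN (...) AND (dag_id NOT IN (...))
print(select(ti).where(*conditions))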
Expand Down Expand Up @@ -285,6 +317,9 @@ def _cleanup_table(
keep_last_filters,
keep_last_group_by,
clean_before_timestamp: DateTime,
dag_id_column=None,
dag_ids=None,
exclude_dag_ids=None,
dry_run: bool = True,
verbose: bool = False,
skip_archive: bool = False,
@@ -297,6 +332,9 @@
query = _build_query(
orm_model=orm_model,
recency_column=recency_column,
dag_id_column=dag_id_column,
dag_ids=dag_ids,
exclude_dag_ids=exclude_dag_ids,
keep_last=keep_last,
keep_last_filters=keep_last_filters,
keep_last_group_by=keep_last_group_by,
@@ -313,10 +351,14 @@
session.commit()


def _confirm_delete(*, date: DateTime, tables: list[str]) -> None:
def _confirm_delete(
*, date: DateTime, tables: list[str], dag_ids: list[str] | None, exclude_dag_ids: list[str] | None
) -> None:
for_tables = f" for tables {tables!r}" if tables else ""
for_dags = f" for the following dags: {dag_ids!r}" if dag_ids else ""
excluding_dags = f" excluding the following dags: {exclude_dag_ids!r}" if exclude_dag_ids else ""
question = (
f"You have requested that we purge all data prior to {date}{for_tables}.\n"
f"You have requested that we purge all data prior to {date}{for_tables}{for_dags}{excluding_dags}."
f"This is irreversible. Consider backing up the tables first and / or doing a dry run "
f"with option --dry-run.\n"
f"Enter 'delete rows' (without quotes) to proceed."
@@ -410,6 +452,8 @@ def run_cleanup(
*,
clean_before_timestamp: DateTime,
table_names: list[str] | None = None,
dag_ids: list[str] | None = None,
exclude_dag_ids: list[str] | None = None,
dry_run: bool = False,
verbose: bool = False,
confirm: bool = True,
@@ -429,6 +473,9 @@
:param clean_before_timestamp: The timestamp before which data should be purged
:param table_names: Optional. List of table names to perform maintenance on. If list not provided,
will perform maintenance on all tables.
:param dag_ids: Optional. List of dag ids to perform maintenance on. If list not provided,
will perform maintenance on all dags.
:param exclude_dag_ids: Optional. List of dag ids to exclude from maintenance.
:param dry_run: If true, print rows meeting deletion criteria
:param verbose: If true, may provide more detailed output.
:param confirm: Require user input to confirm before processing deletions.
@@ -445,13 +492,20 @@
)
_print_config(configs=effective_config_dict)
if not dry_run and confirm:
_confirm_delete(date=clean_before_timestamp, tables=sorted(effective_table_names))
_confirm_delete(
date=clean_before_timestamp,
tables=sorted(effective_table_names),
dag_ids=dag_ids,
exclude_dag_ids=exclude_dag_ids,
)
existing_tables = reflect_tables(tables=None, session=session).tables
for table_name, table_config in effective_config_dict.items():
if table_name in existing_tables:
with _suppress_with_logging(table_name, session):
_cleanup_table(
clean_before_timestamp=clean_before_timestamp,
dag_ids=dag_ids,
exclude_dag_ids=exclude_dag_ids,
dry_run=dry_run,
verbose=verbose,
**table_config.__dict__,
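Putting it together, a minimal dry-run sketch of calling the updated `run_cleanup` directly, based on the signature diffed above (the DAG ids and timestamp are placeholders, and the remaining parameters are left at their defaults):

import pendulum
from airflow.utils.db_cleanup import run_cleanup

run_cleanup(
    clean_before_timestamp=pendulum.datetime(2024, 1, 1, tz="UTC"),
    table_names=["task_instance", "dag_run"],
    dag_ids=["my_etl_dag"],   # only purge rows belonging to this DAG
    exclude_dag_ids=None,     # or pass a list here to invert the filter
    dry_run=True,             # print what would be deleted, delete nothing
    confirm=False,
)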
134 changes: 132 additions & 2 deletions tests/utils/test_db_cleanup.py
@@ -275,6 +275,136 @@ def test__cleanup_table(self, table_name, date_add_kwargs, expected_to_delete, e
else:
raise Exception("unexpected")

@pytest.mark.parametrize(
"table_name, dag_ids, date_add_kwargs, expected_to_delete",
[
pytest.param(
"task_instance",
dict(
dag_to_delete1=f"dag_to_delete_{uuid4()}",
dag_to_delete2=f"dag_to_delete_{uuid4()}",
dag_to_not_delete=f"dag_to_not_delete_{uuid4()}",
),
dict(days=4),
8, # should only delete 8 TI's (4 from each "dag_to_delete" DAG above)
id="only_delete_some_dag_ids",
),
pytest.param(
"task_instance",
dict(
dag_to_implicitly_delete1=f"dag_to_implicitly_delete_{uuid4()}",
dag_to_implicitly_delete2=f"dag_to_implicitly_delete_{uuid4()}",
dag_to_implicitly_delete3=f"dag_to_implicitly_delete_{uuid4()}",
),
dict(days=20),
15, # All DAGs should be deleted since none were passed into the 'dag_ids' param
id="delete_all_dag_ids",
),
pytest.param(
"dag_run",
dict(
dag_to_delete1=f"dag_to_delete_{uuid4()}",
dag_to_delete2=f"dag_to_delete_{uuid4()}",
dag_to_not_delete=f"dag_to_not_delete_{uuid4()}",
),
dict(days=30),
8, # delete 8 DagRuns, 4 from each dag_to_delete DAG (not 5, dag_run has keep_last=True)
id="delete_from_dag_run_table",
),
],
)
def test__cleanup_dag_ids(self, table_name, dag_ids, date_add_kwargs, expected_to_delete):
"""
Verify that _cleanup_table actually deletes the rows it should with dag_ids parameter.
The _cleanup_table should delete all the dags that are in the dag_ids list and
if there are none, delete none.
"""
base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("America/New_York"))
num_tis = 5

for _, dag_id in dag_ids.items():
create_tis(base_date=base_date, num_tis=num_tis, dag_id=dag_id)

with create_session() as session:
clean_before_date = base_date.add(**date_add_kwargs)
_cleanup_table(
**config_dict[table_name].__dict__,
clean_before_timestamp=clean_before_date,
dry_run=False,
session=session,
dag_ids=[dag_id for dag_id in dag_ids.values() if "dag_to_delete" in dag_id],
)

expected_remaining = (num_tis * len(dag_ids)) - expected_to_delete

model = config_dict[table_name].orm_model
assert len(session.query(model).all()) == expected_remaining

@pytest.mark.parametrize(
"table_name, dag_ids, date_add_kwargs, expected_to_delete",
[
pytest.param(
"task_instance",
dict(
dag_to_exclude1=f"dag_to_exclude_{uuid4()}",
dag_to_exclude2=f"dag_to_exclude_{uuid4()}",
dag_to_not_exclude=f"dag_to_not_exclude_{uuid4()}",
),
dict(days=4),
4, # should only delete 4 TI's (all from the dag_to_not_exclude DAG above)
id="only_exclude_some_dag_ids",
),
pytest.param(
"task_instance",
dict(
dag_to_implicitly_delete1=f"dag_to_implicitly_delete_{uuid4()}",
dag_to_implicitly_delete2=f"dag_to_implicitly_delete_{uuid4()}",
dag_to_implicitly_delete3=f"dag_to_implicitly_delete_{uuid4()}",
),
dict(days=20),
15, # All DAGs should be deleted since none were excluded
id="delete_all_dag_ids",
),
pytest.param(
"dag_run",
dict(
dag_to_exclude1=f"dag_to_exclude_{uuid4()}",
dag_to_exclude2=f"dag_to_exclude_{uuid4()}",
dag_to_not_exclude=f"dag_to_not_exclude_{uuid4()}",
),
dict(days=20),
4, # delete 4 DagRuns (not 5, dag_run has keep_last=True)
id="delete_from_dag_run_table",
),
],
)
def test__cleanup_exclude_dag_ids(self, table_name, dag_ids, date_add_kwargs, expected_to_delete):
"""
Verify that _cleanup_table actually deletes the rows it should with the exclude_dag_ids parameter.
The _cleanup_table should delete all the dags that are not in the exclude_dag_ids list and
if there are none, delete all DAGs.
"""
base_date = pendulum.DateTime(2022, 1, 1, tzinfo=pendulum.timezone("America/New_York"))
num_tis = 5 # each test case above has 3 DAGs with 5 ti's each

for _, dag_id in dag_ids.items():
create_tis(base_date=base_date, num_tis=num_tis, dag_id=dag_id)

with create_session() as session:
clean_before_date = base_date.add(**date_add_kwargs)
_cleanup_table(
**config_dict[table_name].__dict__,
clean_before_timestamp=clean_before_date,
dry_run=False,
session=session,
exclude_dag_ids=[dag_id for dag_id in dag_ids.values() if "dag_to_exclude" in dag_id],
)

expected_remaining = (num_tis * len(dag_ids)) - expected_to_delete

model = config_dict[table_name].orm_model
assert len(session.query(model).all()) == expected_remaining

@pytest.mark.parametrize(
"skip_archive, expected_archives",
[pytest.param(True, 1, id="skip_archive"), pytest.param(False, 2, id="do_archive")],
@@ -549,9 +679,9 @@ def test_drop_archived_tables(self, mock_input, confirm_mock, inspect_mock, capl
confirm_mock.assert_not_called()


def create_tis(base_date, num_tis, external_trigger=False):
def create_tis(base_date, num_tis, dag_id=None, external_trigger=False):
with create_session() as session:
dag = DagModel(dag_id=f"test-dag_{uuid4()}")
dag = DagModel(dag_id=dag_id) if dag_id else DagModel(dag_id=f"test-dag_{uuid4()}")
session.add(dag)
for num in range(num_tis):
start_date = base_date.add(days=num)