Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

AIP-72: Handle Custom XCom Backend on Task SDK #47339

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 47 additions & 1 deletion task_sdk/src/airflow/sdk/execution_time/task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,9 @@ def xcom_pull(
run_id: str | None = None,
) -> Any:
"""
Pull XComs that optionally meet certain criteria.
Pull XComs either from the API server (BaseXCom) or from the custom XCom backend, if one is configured.

The pull can optionally be filtered by certain criteria.

:param key: A key for the XCom. If provided, only XComs with matching
keys will be returned. The default key is ``'return_value'``, also
Expand Down Expand Up @@ -305,6 +307,16 @@ def xcom_pull(

xcoms = []
for t in task_ids:
if XCom:
value = XCom.get_one(
run_id=run_id,
key=key,
task_id=t,
dag_id=dag_id,
map_index=map_indexes,
)
xcoms.append(value)
continue
SUPERVISOR_COMMS.send_request(
log=log,
msg=GetXCom(
Expand Down Expand Up @@ -357,6 +369,15 @@ def _xcom_push(ti: RuntimeTaskInstance, key: str, value: Any, mapped_length: int
# consumers
from airflow.serialization.serde import serialize

if XCom:
XCom.set(
key=key,
value=value,
dag_id=ti.dag_id,
task_id=ti.task_id,
run_id=ti.run_id,
)
return
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

How are XCom backends meant to store their reference (e.g., an S3 path) without DB access?

Copy link
Contributor Author

@amoghrajesh amoghrajesh Mar 4, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Right now it's calling it out of models (set). I think it is high time we moved the classes out of models/xcoms.py into execution_time.

# TODO: Move XCom serialization & deserialization to Task SDK
# https://github.com/apache/airflow/issues/45231

Expand Down Expand Up @@ -484,6 +505,8 @@ def send_request(self, log: Logger, msg: SendMsgType):
# 2. Execution (run task code, possibly send requests)
# 3. Shutdown and report status

# Custom XCom backend class, assigned in ``startup()`` from ``resolve_xcom_backend()``.
# Stays ``None`` when no custom backend is configured; XCom pushes and pulls then fall
# through to requests sent via ``SUPERVISOR_COMMS``.
XCom: Any = None


def startup() -> tuple[RuntimeTaskInstance, Logger]:
msg = SUPERVISOR_COMMS.get_message()
Expand All @@ -499,6 +522,9 @@ def startup() -> tuple[RuntimeTaskInstance, Logger]:
with _airflow_parsing_context_manager(dag_id=msg.ti.dag_id, task_id=msg.ti.task_id):
ti = parse(msg)
log.debug("DAG file parsed", file=msg.dag_rel_path)

global XCom
XCom = resolve_xcom_backend(log)
else:
raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")

Expand Down Expand Up @@ -778,6 +804,26 @@ def _push_xcom_if_needed(result: Any, ti: RuntimeTaskInstance, log: Logger):
_xcom_push(ti, "return_value", result, mapped_length=mapped_length)


def resolve_xcom_backend(log: Logger):
    """
    Look up the XCom backend class named by the ``[core] xcom_backend`` option.

    :param log: Logger used to report which backend was selected.
    :returns: The configured custom XCom class, or ``None`` when the option is
        empty or resolves to a class named ``BaseXCom``.
    """
    from airflow.configuration import conf

    backend = conf.getimport("core", "xcom_backend")

    if backend and backend.__name__ != "BaseXCom":
        # NOTE: the class is not validated as a ``BaseXCom`` subclass here; the check
        # is currently disabled, so any importable class is accepted as-is.
        log.info("Custom XCom backend configured, using configured custom XCom backend", clazz=backend)
        return backend

    # Compared by class name only, so the default backend is detected without importing it.
    log.info("Custom XCom backend not configured, using `BaseXCom` as fallback")
    return None


def finalize(
ti: RuntimeTaskInstance, state: TerminalTIState, log: Logger, error: BaseException | None = None
):
Expand Down
6 changes: 6 additions & 0 deletions task_sdk/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -263,6 +263,12 @@ def mock_supervisor_comms():
yield supervisor_comms


@pytest.fixture
def mock_xcom_backend():
    """Swap the task runner's module-level ``XCom`` for a mock and yield that mock to the test."""
    patch_target = "airflow.sdk.execution_time.task_runner.XCom"
    # ``create=True`` lets the patch succeed even if the attribute is missing on the module.
    with mock.patch(patch_target, create=True) as patched_backend:
        yield patched_backend


@pytest.fixture
def mocked_parse(spy_agency):
"""
Expand Down
76 changes: 76 additions & 0 deletions task_sdk/tests/execution_time/test_task_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -1383,6 +1383,82 @@ def execute(self, context):
f"Returned dictionary keys must be strings when using multiple_outputs, found 2 ({int}) instead"
)

def test_xcom_push_to_custom_xcom_backend(
    self, create_runtime_ti, mock_supervisor_comms, mock_xcom_backend
):
    """
    Test that a task pushes its return value to the custom XCom backend.

    With the backend patched in, the value returned by ``execute`` must be handed to
    ``XCom.set`` and no ``SetXCom`` request may be sent through the supervisor comms.
    """

    class CustomOperator(BaseOperator):
        def execute(self, context):
            return "pushing to xcom backend!"

    task = CustomOperator(task_id="pull_task")
    runtime_ti = create_runtime_ti(task=task)

    run(runtime_ti, log=mock.MagicMock())

    mock_xcom_backend.set.assert_called_once_with(
        key="return_value",
        value="pushing to xcom backend!",
        dag_id="test_dag",
        task_id="pull_task",
        run_id="test_run",
    )

    # Assert that the API was not used at all for the push. Look for *any* ``SetXCom``
    # message instead of one exact payload: the value is pushed under ``return_value``,
    # so comparing against a message built with a different key could never match and
    # the assertion would pass even if the API had been called.
    sent_items = [
        item
        for call in mock_supervisor_comms.send_request.call_args_list
        for item in (*call.args, *call.kwargs.values())
    ]
    assert not any(isinstance(item, SetXCom) for item in sent_items)

def test_xcom_pull_from_custom_xcom_backend(
    self, create_runtime_ti, mock_supervisor_comms, mock_xcom_backend
):
    """
    Test that ``xcom_pull`` reads from the custom XCom backend when one is configured.

    With the backend patched in, the pull must be served by ``XCom.get_one`` and no
    ``GetXCom`` request may be sent through the supervisor comms.
    """

    class CustomOperator(BaseOperator):
        def execute(self, context):
            value = context["ti"].xcom_pull(task_ids="pull_task", key="key")
            print(f"Pulled XCom Value: {value}")

    task = CustomOperator(task_id="pull_task")
    runtime_ti = create_runtime_ti(task=task)
    run(runtime_ti, log=mock.MagicMock())

    mock_xcom_backend.get_one.assert_called_once_with(
        key="key",
        dag_id="test_dag",
        task_id="pull_task",
        run_id="test_run",
        map_index=-1,
    )

    # Assert that the API was not used at all for the pull: reject any ``GetXCom``
    # message, not just the one payload that happens to mirror the arguments above.
    sent_items = [
        item
        for call in mock_supervisor_comms.send_request.call_args_list
        for item in (*call.args, *call.kwargs.values())
    ]
    assert not any(isinstance(item, GetXCom) for item in sent_items)


class TestDagParamRuntime:
DEFAULT_ARGS = {
Expand Down