This is an automated email from the ASF dual-hosted git repository.

kaxilnaik pushed a commit to branch v3-1-test
in repository https://gitbox.apache.org/repos/asf/airflow.git

commit 9bbd1cdeb41c45415dd55fdeb2cbc7eefaaf17c2
Author: Amogh Desai <[email protected]>
AuthorDate: Sat Sep 13 06:39:04 2025 -0600

    Fix xcom access in DAG processor callbacks for notifiers (#55542)
    
    (cherry picked from commit ba120e2db157c5b26c8ac7e0ada643333ebecae0)
---
 .../src/airflow/dag_processing/processor.py        |  52 ++++++++-
 .../tests/unit/dag_processing/test_processor.py    | 127 +++++++++++++++++++++
 2 files changed, 177 insertions(+), 2 deletions(-)

diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py
index d4c73e61fea..616bd1abbe9 100644
--- a/airflow-core/src/airflow/dag_processing/processor.py
+++ b/airflow-core/src/airflow/dag_processing/processor.py
@@ -44,12 +44,20 @@ from airflow.sdk.execution_time.comms import (
     GetPreviousDagRun,
     GetPrevSuccessfulDagRun,
     GetVariable,
+    GetXCom,
+    GetXComCount,
+    GetXComSequenceItem,
+    GetXComSequenceSlice,
     MaskSecret,
     OKResponse,
     PreviousDagRunResult,
     PrevSuccessfulDagRunResult,
     PutVariable,
     VariableResult,
+    XComCountResponse,
+    XComResult,
+    XComSequenceIndexResult,
+    XComSequenceSliceResult,
 )
 from airflow.sdk.execution_time.supervisor import WatchedSubprocess
 from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, _send_task_error_email
@@ -107,6 +115,10 @@ ToManager = Annotated[
     | DeleteVariable
     | GetPrevSuccessfulDagRun
     | GetPreviousDagRun
+    | GetXCom
+    | GetXComCount
+    | GetXComSequenceItem
+    | GetXComSequenceSlice
     | MaskSecret,
     Field(discriminator="type"),
 ]
@@ -118,7 +130,11 @@ ToDagProcessor = Annotated[
     | PreviousDagRunResult
     | PrevSuccessfulDagRunResult
     | ErrorResponse
-    | OKResponse,
+    | OKResponse
+    | XComCountResponse
+    | XComResult
+    | XComSequenceIndexResult
+    | XComSequenceSliceResult,
     Field(discriminator="type"),
 ]
 
@@ -459,7 +475,11 @@ class DagFileProcessorProcess(WatchedSubprocess):
         self.send_msg(msg, request_id=0)
 
     def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None:
-        from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse
+        from airflow.sdk.api.datamodels._generated import (
+            ConnectionResponse,
+            VariableResponse,
+            XComSequenceIndexResponse,
+        )
 
         resp: BaseModel | None = None
         dump_opts = {}
@@ -496,6 +516,34 @@ class DagFileProcessorProcess(WatchedSubprocess):
             dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp)
             resp = dagrun_result
             dump_opts = {"exclude_unset": True}
+        elif isinstance(msg, GetXCom):
+            xcom = self.client.xcoms.get(
+                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates
+            )
+            xcom_result = XComResult.from_xcom_response(xcom)
+            resp = xcom_result
+        elif isinstance(msg, GetXComCount):
+            resp = XComCountResponse(len=self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key))
+        elif isinstance(msg, GetXComSequenceItem):
+            xcom = self.client.xcoms.get_sequence_item(
+                msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset
+            )
+            if isinstance(xcom, XComSequenceIndexResponse):
+                resp = XComSequenceIndexResult.from_response(xcom)
+            else:
+                resp = xcom
+        elif isinstance(msg, GetXComSequenceSlice):
+            xcoms = self.client.xcoms.get_sequence_slice(
+                msg.dag_id,
+                msg.run_id,
+                msg.task_id,
+                msg.key,
+                msg.start,
+                msg.stop,
+                msg.step,
+                msg.include_prior_dates,
+            )
+            resp = XComSequenceSliceResult.from_response(xcoms)
         elif isinstance(msg, MaskSecret):
             # Use sdk masker in dag processor and triggerer because those use the task sdk machinery
             from airflow.sdk.log import mask_secret
diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py
index 14427a6dee5..37d9236face 100644
--- a/airflow-core/tests/unit/dag_processing/test_processor.py
+++ b/airflow-core/tests/unit/dag_processing/test_processor.py
@@ -62,6 +62,12 @@ from airflow.sdk import DAG, BaseOperator
 from airflow.sdk.api.client import Client
 from airflow.sdk.api.datamodels._generated import DagRunState
 from airflow.sdk.execution_time import comms
+from airflow.sdk.execution_time.comms import (
+    GetXCom,
+    GetXComSequenceSlice,
+    XComResult,
+    XComSequenceSliceResult,
+)
 from airflow.utils.session import create_session
 from airflow.utils.state import TaskInstanceState
 
@@ -831,6 +837,127 @@ class TestExecuteDagCallbacks:
         # Should log warning about no callback found
         log.warning.assert_called_once_with("Callback requested, but dag didn't have any", dag_id="test_dag")
 
+    @pytest.mark.parametrize(
+        "xcom_operation,expected_message_type,expected_message,mock_response",
+        [
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="report_df", task_ids=task_ids),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="report_df",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="single_value", task_ids=["test_task"]),
+                "GetXComSequenceSlice",
+                GetXComSequenceSlice(
+                    key="single_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    start=None,
+                    stop=None,
+                    step=None,
+                    include_prior_dates=False,
+                ),
+                XComSequenceSliceResult(root=["test data"]),
+            ),
+            (
+                lambda ti, task_ids: ti.xcom_pull(key="direct_value", task_ids="test_task", map_indexes=None),
+                "GetXCom",
+                GetXCom(
+                    key="direct_value",
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    task_id="test_task",
+                    map_index=None,
+                    include_prior_dates=False,
+                ),
+                XComResult(
+                    key="direct_value",
+                    value="test",
+                ),
+            ),
+        ],
+    )
+    def test_notifier_xcom_operations_send_correct_messages(
+        self,
+        spy_agency,
+        mock_supervisor_comms,
+        xcom_operation,
+        expected_message_type,
+        expected_message,
+        mock_response,
+    ):
+        """An XCom pull from a DAG on_success_callback notifier sends exactly one expected comms request."""
+
+        mock_supervisor_comms.send.return_value = mock_response
+
+        class TestNotifier:
+            def __call__(self, context):
+                ti = context["ti"]
+                dag = context["dag"]
+                task_ids = list(dag.task_dict)
+                xcom_operation(ti, task_ids)
+
+        with DAG(dag_id="test_dag", on_success_callback=TestNotifier()) as dag:
+            BaseOperator(task_id="test_task")
+
+        def fake_collect_dags(self, *args, **kwargs):
+            self.dags[dag.dag_id] = dag
+
+        spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag)
+
+        dagbag = DagBag()
+        dagbag.collect_dags()
+
+        current_time = timezone.utcnow()
+        request = DagCallbackRequest(
+            filepath="test.py",
+            dag_id="test_dag",
+            run_id="test_run",
+            bundle_name="testing",
+            bundle_version=None,
+            context_from_server=DagRunContext(
+                dag_run=DRDataModel(
+                    dag_id="test_dag",
+                    run_id="test_run",
+                    logical_date=current_time,
+                    data_interval_start=current_time,
+                    data_interval_end=current_time,
+                    run_after=current_time,
+                    start_date=current_time,
+                    end_date=None,
+                    run_type="manual",
+                    state="success",
+                    consumed_asset_events=[],
+                ),
+                last_ti=TIDataModel(
+                    id=uuid.uuid4(),
+                    dag_id="test_dag",
+                    task_id="test_task",
+                    run_id="test_run",
+                    map_index=-1,
+                    try_number=1,
+                    dag_version_id=uuid.uuid4(),
+                ),
+            ),
+            is_failure_callback=False,
+            msg="Test success message",
+        )
+
+        _execute_dag_callbacks(dagbag, request, structlog.get_logger())
+
+        mock_supervisor_comms.send.assert_called_once_with(msg=expected_message)
+
 
 class TestExecuteTaskCallbacks:
     """Test the _execute_task_callbacks function"""

Reply via email to