This is an automated email from the ASF dual-hosted git repository. kaxilnaik pushed a commit to branch v3-1-test in repository https://gitbox.apache.org/repos/asf/airflow.git
commit 9bbd1cdeb41c45415dd55fdeb2cbc7eefaaf17c2 Author: Amogh Desai <[email protected]> AuthorDate: Sat Sep 13 06:39:04 2025 -0600 Fix xcom access in DAG processor callbacks for notifiers (#55542) (cherry picked from commit ba120e2db157c5b26c8ac7e0ada643333ebecae0) --- .../src/airflow/dag_processing/processor.py | 52 ++++++++- .../tests/unit/dag_processing/test_processor.py | 127 +++++++++++++++++++++ 2 files changed, 177 insertions(+), 2 deletions(-) diff --git a/airflow-core/src/airflow/dag_processing/processor.py b/airflow-core/src/airflow/dag_processing/processor.py index d4c73e61fea..616bd1abbe9 100644 --- a/airflow-core/src/airflow/dag_processing/processor.py +++ b/airflow-core/src/airflow/dag_processing/processor.py @@ -44,12 +44,20 @@ from airflow.sdk.execution_time.comms import ( GetPreviousDagRun, GetPrevSuccessfulDagRun, GetVariable, + GetXCom, + GetXComCount, + GetXComSequenceItem, + GetXComSequenceSlice, MaskSecret, OKResponse, PreviousDagRunResult, PrevSuccessfulDagRunResult, PutVariable, VariableResult, + XComCountResponse, + XComResult, + XComSequenceIndexResult, + XComSequenceSliceResult, ) from airflow.sdk.execution_time.supervisor import WatchedSubprocess from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance, _send_task_error_email @@ -107,6 +115,10 @@ ToManager = Annotated[ | DeleteVariable | GetPrevSuccessfulDagRun | GetPreviousDagRun + | GetXCom + | GetXComCount + | GetXComSequenceItem + | GetXComSequenceSlice | MaskSecret, Field(discriminator="type"), ] @@ -118,7 +130,11 @@ ToDagProcessor = Annotated[ | PreviousDagRunResult | PrevSuccessfulDagRunResult | ErrorResponse - | OKResponse, + | OKResponse + | XComCountResponse + | XComResult + | XComSequenceIndexResult + | XComSequenceSliceResult, Field(discriminator="type"), ] @@ -459,7 +475,11 @@ class DagFileProcessorProcess(WatchedSubprocess): self.send_msg(msg, request_id=0) def _handle_request(self, msg: ToManager, log: FilteringBoundLogger, req_id: int) -> None: - from airflow.sdk.api.datamodels._generated import ConnectionResponse, VariableResponse + from airflow.sdk.api.datamodels._generated import ( + ConnectionResponse, + VariableResponse, + XComSequenceIndexResponse, + ) resp: BaseModel | None = None dump_opts = {} @@ -496,6 +516,34 @@ class DagFileProcessorProcess(WatchedSubprocess): dagrun_result = PrevSuccessfulDagRunResult.from_dagrun_response(dagrun_resp) resp = dagrun_result dump_opts = {"exclude_unset": True} + elif isinstance(msg, GetXCom): + xcom = self.client.xcoms.get( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.map_index, msg.include_prior_dates + ) + xcom_result = XComResult.from_xcom_response(xcom) + resp = xcom_result + elif isinstance(msg, GetXComCount): + resp = self.client.xcoms.head(msg.dag_id, msg.run_id, msg.task_id, msg.key) + elif isinstance(msg, GetXComSequenceItem): + xcom = self.client.xcoms.get_sequence_item( + msg.dag_id, msg.run_id, msg.task_id, msg.key, msg.offset + ) + if isinstance(xcom, XComSequenceIndexResponse): + resp = XComSequenceIndexResult.from_response(xcom) + else: + resp = xcom + elif isinstance(msg, GetXComSequenceSlice): + xcoms = self.client.xcoms.get_sequence_slice( + msg.dag_id, + msg.run_id, + msg.task_id, + msg.key, + msg.start, + msg.stop, + msg.step, + msg.include_prior_dates, + ) + resp = XComSequenceSliceResult.from_response(xcoms) elif isinstance(msg, MaskSecret): # Use sdk masker in dag processor and triggerer because those use the task sdk machinery from airflow.sdk.log import mask_secret diff --git a/airflow-core/tests/unit/dag_processing/test_processor.py b/airflow-core/tests/unit/dag_processing/test_processor.py index 14427a6dee5..37d9236face 100644 --- a/airflow-core/tests/unit/dag_processing/test_processor.py +++ b/airflow-core/tests/unit/dag_processing/test_processor.py @@ -62,6 +62,12 @@ from airflow.sdk import DAG, BaseOperator from airflow.sdk.api.client import Client from airflow.sdk.api.datamodels._generated import DagRunState from airflow.sdk.execution_time import comms +from airflow.sdk.execution_time.comms import ( + GetXCom, + GetXComSequenceSlice, + XComResult, + XComSequenceSliceResult, +) from airflow.utils.session import create_session from airflow.utils.state import TaskInstanceState @@ -831,6 +837,127 @@ class TestExecuteDagCallbacks: # Should log warning about no callback found log.warning.assert_called_once_with("Callback requested, but dag didn't have any", dag_id="test_dag") + @pytest.mark.parametrize( + "xcom_operation,expected_message_type,expected_message,mock_response", + [ + ( + lambda ti, task_ids: ti.xcom_pull(key="report_df", task_ids=task_ids), + "GetXComSequenceSlice", + GetXComSequenceSlice( + key="report_df", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + start=None, + stop=None, + step=None, + include_prior_dates=False, + ), + XComSequenceSliceResult(root=["test data"]), + ), + ( + lambda ti, task_ids: ti.xcom_pull(key="single_value", task_ids=["test_task"]), + "GetXComSequenceSlice", + GetXComSequenceSlice( + key="single_value", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + start=None, + stop=None, + step=None, + include_prior_dates=False, + ), + XComSequenceSliceResult(root=["test data"]), + ), + ( + lambda ti, task_ids: ti.xcom_pull(key="direct_value", task_ids="test_task", map_indexes=None), + "GetXCom", + GetXCom( + key="direct_value", + dag_id="test_dag", + run_id="test_run", + task_id="test_task", + map_index=None, + include_prior_dates=False, + ), + XComResult( + key="direct_value", + value="test", + ), + ), + ], + ) + def test_notifier_xcom_operations_send_correct_messages( + self, + spy_agency, + mock_supervisor_comms, + xcom_operation, + expected_message_type, + expected_message, + mock_response, + ): + """Test that different XCom operations send correct message types""" + + mock_supervisor_comms.send.return_value = mock_response + + class TestNotifier: + def __call__(self, context): + ti = context["ti"] + dag = context["dag"] + task_ids = list(dag.task_dict) + xcom_operation(ti, task_ids) + + with DAG(dag_id="test_dag", on_success_callback=TestNotifier()) as dag: + BaseOperator(task_id="test_task") + + def fake_collect_dags(self, *args, **kwargs): + self.dags[dag.dag_id] = dag + + spy_agency.spy_on(DagBag.collect_dags, call_fake=fake_collect_dags, owner=DagBag) + + dagbag = DagBag() + dagbag.collect_dags() + + current_time = timezone.utcnow() + request = DagCallbackRequest( + filepath="test.py", + dag_id="test_dag", + run_id="test_run", + bundle_name="testing", + bundle_version=None, + context_from_server=DagRunContext( + dag_run=DRDataModel( + dag_id="test_dag", + run_id="test_run", + logical_date=current_time, + data_interval_start=current_time, + data_interval_end=current_time, + run_after=current_time, + start_date=current_time, + end_date=None, + run_type="manual", + state="success", + consumed_asset_events=[], + ), + last_ti=TIDataModel( + id=uuid.uuid4(), + dag_id="test_dag", + task_id="test_task", + run_id="test_run", + map_index=-1, + try_number=1, + dag_version_id=uuid.uuid4(), + ), + ), + is_failure_callback=False, + msg="Test success message", + ) + + _execute_dag_callbacks(dagbag, request, structlog.get_logger()) + + mock_supervisor_comms.send.assert_called_once_with(msg=expected_message) + class TestExecuteTaskCallbacks: """Test the _execute_task_callbacks function"""
