Copilot commented on code in PR #64770:
URL: https://github.com/apache/airflow/pull/64770#discussion_r3066476162


##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,10 +197,132 @@ def __init__(
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
         )
+        self.virtual_cluster_id = virtual_cluster_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill
 
     def hook(self) -> AwsGenericHook:
         return EmrContainerHook(aws_conn_id=self.aws_conn_id)
 
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x 
compatibility)."""
+            from sqlalchemy import select
+
+            query = select(TaskInstance).where(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = session.scalars(query).one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except (KeyError, TypeError) as e:
+            raise ValueError(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            ) from e
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the EMR container job.
+
+        Returns True if task is NOT DEFERRED (user-initiated cancellation).
+        Returns False if task is DEFERRED (triggerer restart - don't cancel 
job).
+        """
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Run the trigger and wait for the job to complete.
+
+        If the task is cancelled while waiting, attempt to cancel the EMR 
container job
+        if cancel_on_kill is enabled and it's safe to do so.
+        """
+        hook: EmrContainerHook = self.hook()  # type: ignore[assignment]
+        try:
+            async with await hook.get_async_conn() as client:
+                waiter = hook.get_waiter(
+                    self.waiter_name,
+                    deferrable=True,
+                    client=client,
+                    config_overrides=self.waiter_config_overrides,
+                )
+                await async_wait(
+                    waiter,
+                    self.waiter_delay,
+                    self.attempts,
+                    self.waiter_args,
+                    self.failure_message,
+                    self.status_message,
+                    self.status_queries,
+                )
+            yield TriggerEvent({"status": "success", self.return_key: 
self.return_value})
+        except asyncio.CancelledError:
+            try:
+                if self.job_id and self.cancel_on_kill and await 
self.safe_to_cancel():
+                    self.log.info(
+                        "Task was cancelled. Cancelling EMR container job. "
+                        "Virtual Cluster ID: %s, Job ID: %s",
+                        self.virtual_cluster_id,
+                        self.job_id,
+                    )
+                    try:
+                        hook.stop_query(self.job_id)
+                        self.log.info("EMR container job %s cancelled 
successfully.", self.job_id)
+                    except Exception:
+                        self.log.exception(
+                            "Failed to cancel EMR container job %s. The job 
may still be running.",
+                            self.job_id,
+                        )
+                else:
+                    self.log.info(
+                        "Trigger may have shutdown or cancel_on_kill is 
disabled. "
+                        "Skipping job cancellation. Virtual Cluster ID: %s, 
Job ID: %s",
+                        self.virtual_cluster_id,
+                        self.job_id,
+                    )
+            except asyncio.CancelledError:
+                raise
+            except Exception:
+                self.log.exception(
+                    "Error during cancellation check for EMR container job %s. 
The job may still be running.",
+                    self.job_id,
+                )
+            raise
+        except AirflowException as e:
+            yield TriggerEvent({"status": "error", "message": str(e), 
self.return_key: self.return_value})

Review Comment:
   Overriding `run()` narrows error handling to only `AirflowException`. Any 
other exception from `get_async_conn()`, `get_waiter()`, or `async_wait()` 
(e.g., boto/connection errors that are not wrapped as `AirflowException`) will 
now bubble out of the trigger instead of producing an error `TriggerEvent`, 
changing behavior from the base waiter trigger. Add a final `except Exception 
as e:` branch that yields an `{"status": "error"...}` event (while keeping 
`CancelledError` re-raised) to preserve expected deferrable error propagation.
   ```suggestion
               yield TriggerEvent({"status": "error", "message": str(e), 
self.return_key: self.return_value})
           except Exception as e:
               yield TriggerEvent({"status": "error", "message": str(e), 
self.return_key: self.return_value})
   ```



##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,10 +197,132 @@ def __init__(
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
         )
+        self.virtual_cluster_id = virtual_cluster_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill
 
     def hook(self) -> AwsGenericHook:
         return EmrContainerHook(aws_conn_id=self.aws_conn_id)
 
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x 
compatibility)."""
+            from sqlalchemy import select
+
+            query = select(TaskInstance).where(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = session.scalars(query).one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except (KeyError, TypeError) as e:
+            raise ValueError(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            ) from e
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the EMR container job.
+
+        Returns True if task is NOT DEFERRED (user-initiated cancellation).
+        Returns False if task is DEFERRED (triggerer restart - don't cancel 
job).
+        """
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]
+            task_state = task_instance.state
+        return task_state != TaskInstanceState.DEFERRED
+
+    async def run(self) -> AsyncIterator[TriggerEvent]:
+        """
+        Run the trigger and wait for the job to complete.
+
+        If the task is cancelled while waiting, attempt to cancel the EMR 
container job
+        if cancel_on_kill is enabled and it's safe to do so.
+        """
+        hook: EmrContainerHook = self.hook()  # type: ignore[assignment]
+        try:
+            async with await hook.get_async_conn() as client:
+                waiter = hook.get_waiter(
+                    self.waiter_name,
+                    deferrable=True,
+                    client=client,
+                    config_overrides=self.waiter_config_overrides,
+                )
+                await async_wait(
+                    waiter,
+                    self.waiter_delay,
+                    self.attempts,
+                    self.waiter_args,
+                    self.failure_message,
+                    self.status_message,
+                    self.status_queries,
+                )
+            yield TriggerEvent({"status": "success", self.return_key: 
self.return_value})
+        except asyncio.CancelledError:
+            try:
+                if self.job_id and self.cancel_on_kill and await 
self.safe_to_cancel():
+                    self.log.info(
+                        "Task was cancelled. Cancelling EMR container job. "
+                        "Virtual Cluster ID: %s, Job ID: %s",
+                        self.virtual_cluster_id,
+                        self.job_id,
+                    )
+                    try:
+                        hook.stop_query(self.job_id)

Review Comment:
   `safe_to_cancel()` calls a synchronous DB query (`get_task_instance()`) from 
an async context (Airflow 2.x path), and the cancellation path calls 
synchronous `hook.stop_query()` from the trigger event loop. Both can block the 
triggerer loop and degrade reliability under load. Prefer running these sync 
operations via `sync_to_async(...)` (or an async-capable hook/client method) so 
cancellation checks and stop calls don't block other triggers.
   ```suggestion
                           await sync_to_async(hook.stop_query)(self.job_id)
   ```



##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,10 +197,132 @@ def __init__(
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
         )
+        self.virtual_cluster_id = virtual_cluster_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill
 
     def hook(self) -> AwsGenericHook:
         return EmrContainerHook(aws_conn_id=self.aws_conn_id)
 
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x 
compatibility)."""
+            from sqlalchemy import select
+
+            query = select(TaskInstance).where(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = session.scalars(query).one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+        except (KeyError, TypeError) as e:
+            raise ValueError(
+                f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                f"task_id: {self.task_instance.task_id}, "
+                f"run_id: {self.task_instance.run_id} and "
+                f"map_index: {self.task_instance.map_index} is not found"
+            ) from e
+        return task_state
+
+    async def safe_to_cancel(self) -> bool:
+        """
+        Whether it is safe to cancel the EMR container job.
+
+        Returns True if task is NOT DEFERRED (user-initiated cancellation).
+        Returns False if task is DEFERRED (triggerer restart - don't cancel 
job).
+        """
+        if AIRFLOW_V_3_0_PLUS:
+            task_state = await self.get_task_state()
+        else:
+            task_instance = self.get_task_instance()  # type: ignore[call-arg]

Review Comment:
   `safe_to_cancel()` calls a synchronous DB query (`get_task_instance()`) from 
an async context (Airflow 2.x path), and the cancellation path calls 
synchronous `hook.stop_query()` from the trigger event loop. Both can block the 
triggerer loop and degrade reliability under load. Prefer running these sync 
operations via `sync_to_async(...)` (or an async-capable hook/client method) so 
cancellation checks and stop calls don't block other triggers.
   ```suggestion
               task_instance = await sync_to_async(self.get_task_instance, 
thread_sensitive=True)()  # type: ignore[call-arg]
   ```



##########
providers/amazon/src/airflow/providers/amazon/aws/operators/emr.py:
##########
@@ -567,13 +567,15 @@ def execute(self, context: Context) -> str | None:
                     aws_conn_id=self.aws_conn_id,
                     waiter_delay=self.poll_interval,
                     waiter_max_attempts=self.max_polling_attempts,
+                    cancel_on_kill=True,
                 )
                 if self.max_polling_attempts
                 else EmrContainerTrigger(
                     virtual_cluster_id=self.virtual_cluster_id,
                     job_id=self.job_id,
                     aws_conn_id=self.aws_conn_id,
                     waiter_delay=self.poll_interval,
+                    cancel_on_kill=True,
                 ),

Review Comment:
   The PR description states a new `cancel_on_kill` parameter is added "for 
opt-out", but `EmrContainerOperator` hard-codes `cancel_on_kill=True` when 
instantiating the trigger, leaving no opt-out path for operator users. To match 
the stated behavior, consider adding an operator parameter (e.g., 
`cancel_on_kill: bool = True`) and pass it through to `EmrContainerTrigger`.



##########
providers/amazon/tests/unit/amazon/aws/operators/test_emr_containers.py:
##########
@@ -162,6 +162,19 @@ def test_operator_defer_with_timeout(self, 
mock_submit_job, mock_check_query_sta
         assert trigger.waiter_delay == self.emr_container.poll_interval
         assert trigger.attempts == self.emr_container.max_polling_attempts
 
+    @mock.patch.object(EmrContainerHook, "stop_query")
+    def test_execute_complete_cancels_job_on_failure(self, mock_stop_query):
+        self.emr_container.job_id = "test_job_id"
+        event = {"status": "error", "message": "Job timed out", "job_id": 
"test_job_id"}
+        with pytest.raises(AirflowException):
+            self.emr_container.execute_complete(context=None, event=event)
+        mock_stop_query.assert_called_once_with("test_job_id")
+

Review Comment:
   Current coverage verifies `stop_query()` is invoked on failure, but does not 
cover the branch where `stop_query()` raises and the operator should still 
raise the original `AirflowException` (i.e., cancellation failure must not mask 
the task failure). Add a unit test where `mock_stop_query.side_effect = 
Exception(...)` and assert `execute_complete()` still raises `AirflowException` 
while logging the cancellation error.
   ```suggestion
   
       @mock.patch.object(EmrContainerHook, "stop_query")
       def test_execute_complete_raises_original_error_when_cancel_fails(self, 
mock_stop_query, caplog):
           self.emr_container.job_id = "test_job_id"
           mock_stop_query.side_effect = Exception("Failed to cancel job")
           event = {"status": "error", "message": "Job timed out", "job_id": 
"test_job_id"}
   
           with pytest.raises(AirflowException):
               self.emr_container.execute_complete(context=None, event=event)
   
           mock_stop_query.assert_called_once_with("test_job_id")
           assert "Failed to cancel job" in caplog.text
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to