Copilot commented on code in PR #64770:
URL: https://github.com/apache/airflow/pull/64770#discussion_r3067080788


##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,9 +197,133 @@ def __init__(
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
         )
+        self.virtual_cluster_id = virtual_cluster_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill
 
     def hook(self) -> AwsGenericHook:
-        return EmrContainerHook(aws_conn_id=self.aws_conn_id)
+        return EmrContainerHook(aws_conn_id=self.aws_conn_id, 
virtual_cluster_id=self.virtual_cluster_id)
+
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x 
compatibility)."""
+            from sqlalchemy import select
+
+            query = select(TaskInstance).where(

Review Comment:
   This import (`from sqlalchemy import select`) is performed inside a method 
body. Per Airflow codebase guidelines, imports should live at module scope 
unless there is a strong circular-import or lazy-loading reason. Consider 
moving it into the existing module-level `if not AIRFLOW_V_3_0_PLUS:` block to 
avoid repeated import statements and keep import behavior consistent.



##########
providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py:
##########
@@ -152,8 +154,179 @@ def test_serialization_default_max_attempts(self):
             "waiter_delay": 30,
             "waiter_max_attempts": sys.maxsize,
             "aws_conn_id": "aws_default",
+            "cancel_on_kill": True,
         }
 
+    def test_serialization_includes_cancel_on_kill(self):
+        """Test that cancel_on_kill=True is correctly serialized."""
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+        assert kwargs["cancel_on_kill"] is True
+
+    def test_serialization_cancel_on_kill_false(self):
+        """Test that cancel_on_kill=False is correctly serialized."""
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=False,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+        assert kwargs["cancel_on_kill"] is False
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_cancels_job_on_killed_when_safe(self, 
mock_safe_to_cancel, mock_async_wait):
+        """
+        Test that EmrContainerTrigger cancels the job when task is killed
+        and safe_to_cancel returns True.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+        mock_hook.stop_query.return_value = {"ResponseMetadata": 
{"HTTPStatusCode": 200}}
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.stop_query.assert_called_once_with("test_job")
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_no_cancel_when_unsafe(self, mock_safe_to_cancel, 
mock_async_wait):
+        """
+        Test that EmrContainerTrigger does NOT cancel the job when
+        safe_to_cancel returns False (e.g., triggerer shutdown).
+        """
+        mock_safe_to_cancel.return_value = False
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.stop_query.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_no_cancel_when_disabled(self, mock_safe_to_cancel, 
mock_async_wait):
+        """
+        Test that EmrContainerTrigger does NOT cancel the job when
+        cancel_on_kill=False.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=False,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.stop_query.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    async def test_run_yields_error_on_airflow_exception(self, 
mock_async_wait):
+        """Test that an AirflowException yields an error TriggerEvent."""
+        mock_async_wait.side_effect = AirflowException("Something went wrong")
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+

Review Comment:
   This new test also relies on unspecced `MagicMock` instances for the 
hook/client/context manager. Using autospecced mocks would better validate the 
`EmrContainerHook` API and avoid silent passes if production method 
names/signatures change.



##########
providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py:
##########
@@ -152,8 +154,179 @@ def test_serialization_default_max_attempts(self):
             "waiter_delay": 30,
             "waiter_max_attempts": sys.maxsize,
             "aws_conn_id": "aws_default",
+            "cancel_on_kill": True,
         }
 
+    def test_serialization_includes_cancel_on_kill(self):
+        """Test that cancel_on_kill=True is correctly serialized."""
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+        assert kwargs["cancel_on_kill"] is True
+
+    def test_serialization_cancel_on_kill_false(self):
+        """Test that cancel_on_kill=False is correctly serialized."""
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=False,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+        assert kwargs["cancel_on_kill"] is False
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_cancels_job_on_killed_when_safe(self, 
mock_safe_to_cancel, mock_async_wait):
+        """
+        Test that EmrContainerTrigger cancels the job when task is killed
+        and safe_to_cancel returns True.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+        mock_hook.stop_query.return_value = {"ResponseMetadata": 
{"HTTPStatusCode": 200}}
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()

Review Comment:
   These tests create `mock.MagicMock()` objects without a `spec`/`autospec`, 
which can hide attribute/typo bugs (e.g., `stop_query`/`get_async_conn` 
naming). Prefer `mock.create_autospec(EmrContainerHook, instance=True)` (or 
`MagicMock(spec=...)`) and `autospec=True` patches so the test fails if the 
production API changes.
   ```suggestion
           mock_hook = mock.MagicMock(spec=["get_waiter", "stop_query", 
"get_async_conn"])
           mock_hook.get_waiter.return_value = mock.MagicMock(spec=[])
           mock_hook.stop_query.return_value = {"ResponseMetadata": 
{"HTTPStatusCode": 200}}
   
           mock_client = object()
           mock_async_cm = mock.MagicMock(spec=["__aenter__", "__aexit__"])
   ```



##########
providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py:
##########
@@ -152,8 +154,179 @@ def test_serialization_default_max_attempts(self):
             "waiter_delay": 30,
             "waiter_max_attempts": sys.maxsize,
             "aws_conn_id": "aws_default",
+            "cancel_on_kill": True,
         }
 
+    def test_serialization_includes_cancel_on_kill(self):
+        """Test that cancel_on_kill=True is correctly serialized."""
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+        assert kwargs["cancel_on_kill"] is True
+
+    def test_serialization_cancel_on_kill_false(self):
+        """Test that cancel_on_kill=False is correctly serialized."""
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=False,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+        assert kwargs["cancel_on_kill"] is False
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_cancels_job_on_killed_when_safe(self, 
mock_safe_to_cancel, mock_async_wait):
+        """
+        Test that EmrContainerTrigger cancels the job when task is killed
+        and safe_to_cancel returns True.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+        mock_hook.stop_query.return_value = {"ResponseMetadata": 
{"HTTPStatusCode": 200}}
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.stop_query.assert_called_once_with("test_job")
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_no_cancel_when_unsafe(self, mock_safe_to_cancel, 
mock_async_wait):
+        """
+        Test that EmrContainerTrigger does NOT cancel the job when
+        safe_to_cancel returns False (e.g., triggerer shutdown).
+        """
+        mock_safe_to_cancel.return_value = False
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.stop_query.assert_not_called()
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_no_cancel_when_disabled(self, mock_safe_to_cancel, 
mock_async_wait):
+        """
+        Test that EmrContainerTrigger does NOT cancel the job when
+        cancel_on_kill=False.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=False,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+

Review Comment:
   `mock.MagicMock()` objects here are not specced/autospecced. Consider using 
`autospec=True`/`create_autospec` for the hook and context manager to ensure 
the test catches API mismatches (e.g. signature/name changes on 
`get_async_conn`, `get_waiter`).



##########
providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py:
##########
@@ -152,8 +154,179 @@ def test_serialization_default_max_attempts(self):
             "waiter_delay": 30,
             "waiter_max_attempts": sys.maxsize,
             "aws_conn_id": "aws_default",
+            "cancel_on_kill": True,
         }
 
+    def test_serialization_includes_cancel_on_kill(self):
+        """Test that cancel_on_kill=True is correctly serialized."""
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+        assert kwargs["cancel_on_kill"] is True
+
+    def test_serialization_cancel_on_kill_false(self):
+        """Test that cancel_on_kill=False is correctly serialized."""
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=False,
+        )
+        classpath, kwargs = trigger.serialize()
+        assert classpath == 
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+        assert kwargs["cancel_on_kill"] is False
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_cancels_job_on_killed_when_safe(self, 
mock_safe_to_cancel, mock_async_wait):
+        """
+        Test that EmrContainerTrigger cancels the job when task is killed
+        and safe_to_cancel returns True.
+        """
+        mock_safe_to_cancel.return_value = True
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+        mock_hook.stop_query.return_value = {"ResponseMetadata": 
{"HTTPStatusCode": 200}}
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+        with mock.patch.object(trigger, "hook", return_value=mock_hook):
+            with pytest.raises(asyncio.CancelledError):
+                async for _ in trigger.run():
+                    pass
+
+        mock_hook.stop_query.assert_called_once_with("test_job")
+
+    @pytest.mark.asyncio
+    @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+    
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+    async def test_run_no_cancel_when_unsafe(self, mock_safe_to_cancel, 
mock_async_wait):
+        """
+        Test that EmrContainerTrigger does NOT cancel the job when
+        safe_to_cancel returns False (e.g., triggerer shutdown).
+        """
+        mock_safe_to_cancel.return_value = False
+        mock_async_wait.side_effect = asyncio.CancelledError()
+
+        trigger = EmrContainerTrigger(
+            virtual_cluster_id="test_cluster",
+            job_id="test_job",
+            waiter_delay=30,
+            waiter_max_attempts=60,
+            aws_conn_id="aws_default",
+            cancel_on_kill=True,
+        )
+
+        mock_hook = mock.MagicMock()
+        mock_hook.get_waiter.return_value = mock.MagicMock()
+
+        mock_client = mock.MagicMock()
+        mock_async_cm = mock.MagicMock()
+        mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+        mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+        mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+

Review Comment:
   `mock.MagicMock()` is used here without `spec`/`autospec`, which makes the 
test permissive to wrong method names and missing attributes. Using an 
autospecced `EmrContainerHook`/async context manager mock would keep this test 
aligned with the real hook interface.



##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,9 +197,133 @@ def __init__(
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
         )
+        self.virtual_cluster_id = virtual_cluster_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill
 
     def hook(self) -> AwsGenericHook:
-        return EmrContainerHook(aws_conn_id=self.aws_conn_id)
+        return EmrContainerHook(aws_conn_id=self.aws_conn_id, 
virtual_cluster_id=self.virtual_cluster_id)
+
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x 
compatibility)."""
+            from sqlalchemy import select
+
+            query = select(TaskInstance).where(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = session.scalars(query).one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,

Review Comment:
   `get_task_state()` imports `RuntimeTaskInstance` inside the function. If 
this import is only needed for Airflow 3, prefer a module-level conditional 
import (e.g. under `if AIRFLOW_V_3_0_PLUS:`) so that the import is resolved 
once at module load time and function-scope imports are kept out of hot paths.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to