Copilot commented on code in PR #64770:
URL: https://github.com/apache/airflow/pull/64770#discussion_r3067080788
##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,9 +197,133 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)
+ self.virtual_cluster_id = virtual_cluster_id
+ self.job_id = job_id
+ self.cancel_on_kill = cancel_on_kill
def hook(self) -> AwsGenericHook:
- return EmrContainerHook(aws_conn_id=self.aws_conn_id)
+ return EmrContainerHook(aws_conn_id=self.aws_conn_id,
virtual_cluster_id=self.virtual_cluster_id)
+
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """Get the task instance for the current trigger (Airflow 2.x
compatibility)."""
+ from sqlalchemy import select
+
+ query = select(TaskInstance).where(
Review Comment:
Imports are being done inside method bodies (`from sqlalchemy import
select`). Per Airflow codebase guidelines, imports should be at module scope
unless there is a strong circular/lazy-load reason. Consider moving this import
to the existing `if not AIRFLOW_V_3_0_PLUS:` module-level block to avoid
repeated imports and keep import behavior consistent.
##########
providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py:
##########
@@ -152,8 +154,179 @@ def test_serialization_default_max_attempts(self):
"waiter_delay": 30,
"waiter_max_attempts": sys.maxsize,
"aws_conn_id": "aws_default",
+ "cancel_on_kill": True,
}
+ def test_serialization_includes_cancel_on_kill(self):
+ """Test that cancel_on_kill=True is correctly serialized."""
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs["cancel_on_kill"] is True
+
+ def test_serialization_cancel_on_kill_false(self):
+ """Test that cancel_on_kill=False is correctly serialized."""
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=False,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs["cancel_on_kill"] is False
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_cancels_job_on_killed_when_safe(self,
mock_safe_to_cancel, mock_async_wait):
+ """
+ Test that EmrContainerTrigger cancels the job when task is killed
+ and safe_to_cancel returns True.
+ """
+ mock_safe_to_cancel.return_value = True
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+ mock_hook.stop_query.return_value = {"ResponseMetadata":
{"HTTPStatusCode": 200}}
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+ mock_hook.stop_query.assert_called_once_with("test_job")
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_no_cancel_when_unsafe(self, mock_safe_to_cancel,
mock_async_wait):
+ """
+ Test that EmrContainerTrigger does NOT cancel the job when
+ safe_to_cancel returns False (e.g., triggerer shutdown).
+ """
+ mock_safe_to_cancel.return_value = False
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+ mock_hook.stop_query.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_no_cancel_when_disabled(self, mock_safe_to_cancel,
mock_async_wait):
+ """
+ Test that EmrContainerTrigger does NOT cancel the job when
+ cancel_on_kill=False.
+ """
+ mock_safe_to_cancel.return_value = True
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=False,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+ mock_hook.stop_query.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+ async def test_run_yields_error_on_airflow_exception(self,
mock_async_wait):
+ """Test that an AirflowException yields an error TriggerEvent."""
+ mock_async_wait.side_effect = AirflowException("Something went wrong")
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
Review Comment:
This new test also relies on unspecced `MagicMock` instances for the
hook/client/context manager. Using autospecced mocks would better validate the
`EmrContainerHook` API and avoid silent passes if production method
names/signatures change.
##########
providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py:
##########
@@ -152,8 +154,179 @@ def test_serialization_default_max_attempts(self):
"waiter_delay": 30,
"waiter_max_attempts": sys.maxsize,
"aws_conn_id": "aws_default",
+ "cancel_on_kill": True,
}
+ def test_serialization_includes_cancel_on_kill(self):
+ """Test that cancel_on_kill=True is correctly serialized."""
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs["cancel_on_kill"] is True
+
+ def test_serialization_cancel_on_kill_false(self):
+ """Test that cancel_on_kill=False is correctly serialized."""
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=False,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs["cancel_on_kill"] is False
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_cancels_job_on_killed_when_safe(self,
mock_safe_to_cancel, mock_async_wait):
+ """
+ Test that EmrContainerTrigger cancels the job when task is killed
+ and safe_to_cancel returns True.
+ """
+ mock_safe_to_cancel.return_value = True
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+ mock_hook.stop_query.return_value = {"ResponseMetadata":
{"HTTPStatusCode": 200}}
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
Review Comment:
These tests create `mock.MagicMock()` objects without a `spec`/`autospec`,
which can hide attribute/typo bugs (e.g., `stop_query`/`get_async_conn`
naming). Prefer `mock.create_autospec(EmrContainerHook, instance=True)` (or
`MagicMock(spec=...)`) and `autospec=True` patches so the test fails if the
production API changes.
```suggestion
mock_hook = mock.MagicMock(spec=["get_waiter", "stop_query",
"get_async_conn"])
mock_hook.get_waiter.return_value = mock.MagicMock(spec=[])
mock_hook.stop_query.return_value = {"ResponseMetadata":
{"HTTPStatusCode": 200}}
mock_client = object()
mock_async_cm = mock.MagicMock(spec=["__aenter__", "__aexit__"])
```
##########
providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py:
##########
@@ -152,8 +154,179 @@ def test_serialization_default_max_attempts(self):
"waiter_delay": 30,
"waiter_max_attempts": sys.maxsize,
"aws_conn_id": "aws_default",
+ "cancel_on_kill": True,
}
+ def test_serialization_includes_cancel_on_kill(self):
+ """Test that cancel_on_kill=True is correctly serialized."""
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs["cancel_on_kill"] is True
+
+ def test_serialization_cancel_on_kill_false(self):
+ """Test that cancel_on_kill=False is correctly serialized."""
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=False,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs["cancel_on_kill"] is False
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_cancels_job_on_killed_when_safe(self,
mock_safe_to_cancel, mock_async_wait):
+ """
+ Test that EmrContainerTrigger cancels the job when task is killed
+ and safe_to_cancel returns True.
+ """
+ mock_safe_to_cancel.return_value = True
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+ mock_hook.stop_query.return_value = {"ResponseMetadata":
{"HTTPStatusCode": 200}}
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+ mock_hook.stop_query.assert_called_once_with("test_job")
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_no_cancel_when_unsafe(self, mock_safe_to_cancel,
mock_async_wait):
+ """
+ Test that EmrContainerTrigger does NOT cancel the job when
+ safe_to_cancel returns False (e.g., triggerer shutdown).
+ """
+ mock_safe_to_cancel.return_value = False
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+ mock_hook.stop_query.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_no_cancel_when_disabled(self, mock_safe_to_cancel,
mock_async_wait):
+ """
+ Test that EmrContainerTrigger does NOT cancel the job when
+ cancel_on_kill=False.
+ """
+ mock_safe_to_cancel.return_value = True
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=False,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
Review Comment:
`mock.MagicMock()` objects here are not specced/autospecced. Consider using
`autospec=True`/`create_autospec` for the hook and context manager to ensure
the test catches API mismatches (e.g. signature/name changes on
`get_async_conn`, `get_waiter`).
##########
providers/amazon/tests/unit/amazon/aws/triggers/test_emr.py:
##########
@@ -152,8 +154,179 @@ def test_serialization_default_max_attempts(self):
"waiter_delay": 30,
"waiter_max_attempts": sys.maxsize,
"aws_conn_id": "aws_default",
+ "cancel_on_kill": True,
}
+ def test_serialization_includes_cancel_on_kill(self):
+ """Test that cancel_on_kill=True is correctly serialized."""
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs["cancel_on_kill"] is True
+
+ def test_serialization_cancel_on_kill_false(self):
+ """Test that cancel_on_kill=False is correctly serialized."""
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=False,
+ )
+ classpath, kwargs = trigger.serialize()
+ assert classpath ==
"airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger"
+ assert kwargs["cancel_on_kill"] is False
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_cancels_job_on_killed_when_safe(self,
mock_safe_to_cancel, mock_async_wait):
+ """
+ Test that EmrContainerTrigger cancels the job when task is killed
+ and safe_to_cancel returns True.
+ """
+ mock_safe_to_cancel.return_value = True
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+ mock_hook.stop_query.return_value = {"ResponseMetadata":
{"HTTPStatusCode": 200}}
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
+ with mock.patch.object(trigger, "hook", return_value=mock_hook):
+ with pytest.raises(asyncio.CancelledError):
+ async for _ in trigger.run():
+ pass
+
+ mock_hook.stop_query.assert_called_once_with("test_job")
+
+ @pytest.mark.asyncio
+ @mock.patch("airflow.providers.amazon.aws.triggers.emr.async_wait")
+
@mock.patch("airflow.providers.amazon.aws.triggers.emr.EmrContainerTrigger.safe_to_cancel")
+ async def test_run_no_cancel_when_unsafe(self, mock_safe_to_cancel,
mock_async_wait):
+ """
+ Test that EmrContainerTrigger does NOT cancel the job when
+ safe_to_cancel returns False (e.g., triggerer shutdown).
+ """
+ mock_safe_to_cancel.return_value = False
+ mock_async_wait.side_effect = asyncio.CancelledError()
+
+ trigger = EmrContainerTrigger(
+ virtual_cluster_id="test_cluster",
+ job_id="test_job",
+ waiter_delay=30,
+ waiter_max_attempts=60,
+ aws_conn_id="aws_default",
+ cancel_on_kill=True,
+ )
+
+ mock_hook = mock.MagicMock()
+ mock_hook.get_waiter.return_value = mock.MagicMock()
+
+ mock_client = mock.MagicMock()
+ mock_async_cm = mock.MagicMock()
+ mock_async_cm.__aenter__ = mock.AsyncMock(return_value=mock_client)
+ mock_async_cm.__aexit__ = mock.AsyncMock(return_value=None)
+ mock_hook.get_async_conn = mock.AsyncMock(return_value=mock_async_cm)
+
Review Comment:
`mock.MagicMock()` is used here without `spec`/`autospec`, which makes the
test tolerant of wrong method names and missing attributes. Using an
autospecced `EmrContainerHook`/async context manager mock would keep this test
aligned with the real hook interface.
##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,9 +197,133 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)
+ self.virtual_cluster_id = virtual_cluster_id
+ self.job_id = job_id
+ self.cancel_on_kill = cancel_on_kill
def hook(self) -> AwsGenericHook:
- return EmrContainerHook(aws_conn_id=self.aws_conn_id)
+ return EmrContainerHook(aws_conn_id=self.aws_conn_id,
virtual_cluster_id=self.virtual_cluster_id)
+
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """Get the task instance for the current trigger (Airflow 2.x
compatibility)."""
+ from sqlalchemy import select
+
+ query = select(TaskInstance).where(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ task_instance = session.scalars(query).one_or_none()
+ if task_instance is None:
+ raise ValueError(
+ f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+ f"task_id: {self.task_instance.task_id}, "
+ f"run_id: {self.task_instance.run_id} and "
+ f"map_index: {self.task_instance.map_index} is not found"
+ )
+ return task_instance
+
+ async def get_task_state(self):
+ """Get the current state of the task instance (Airflow 3.x)."""
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
Review Comment:
`get_task_state()` imports `RuntimeTaskInstance` inside the function. If
this import is only needed for Airflow 3, prefer a module-level conditional
import (e.g. under `if AIRFLOW_V_3_0_PLUS:`) so the import happens once and
avoids function-scope imports in hot paths.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]