aurangzaib048 commented on code in PR #64770:
URL: https://github.com/apache/airflow/pull/64770#discussion_r3067055882
##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,10 +197,132 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)
+ self.virtual_cluster_id = virtual_cluster_id
+ self.job_id = job_id
+ self.cancel_on_kill = cancel_on_kill
def hook(self) -> AwsGenericHook:
return EmrContainerHook(aws_conn_id=self.aws_conn_id)
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """Get the task instance for the current trigger (Airflow 2.x
compatibility)."""
+ from sqlalchemy import select
+
+ query = select(TaskInstance).where(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ task_instance = session.scalars(query).one_or_none()
+ if task_instance is None:
+ raise ValueError(
+ f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+ f"task_id: {self.task_instance.task_id}, "
+ f"run_id: {self.task_instance.run_id} and "
+ f"map_index: {self.task_instance.map_index} is not found"
+ )
+ return task_instance
+
+ async def get_task_state(self):
+ """Get the current state of the task instance (Airflow 3.x)."""
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except (KeyError, TypeError) as e:
+ raise ValueError(
+ f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+ f"task_id: {self.task_instance.task_id}, "
+ f"run_id: {self.task_instance.run_id} and "
+ f"map_index: {self.task_instance.map_index} is not found"
+ ) from e
+ return task_state
+
+ async def safe_to_cancel(self) -> bool:
+ """
+ Whether it is safe to cancel the EMR container job.
+
+ Returns True if task is NOT DEFERRED (user-initiated cancellation).
+ Returns False if task is DEFERRED (triggerer restart - don't cancel
job).
+ """
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
+ task_state = task_instance.state
+ return task_state != TaskInstanceState.DEFERRED
+
+ async def run(self) -> AsyncIterator[TriggerEvent]:
+ """
+ Run the trigger and wait for the job to complete.
+
+ If the task is cancelled while waiting, attempt to cancel the EMR
container job
+ if cancel_on_kill is enabled and it's safe to do so.
+ """
+ hook: EmrContainerHook = self.hook() # type: ignore[assignment]
+ try:
+ async with await hook.get_async_conn() as client:
+ waiter = hook.get_waiter(
+ self.waiter_name,
+ deferrable=True,
+ client=client,
+ config_overrides=self.waiter_config_overrides,
+ )
+ await async_wait(
+ waiter,
+ self.waiter_delay,
+ self.attempts,
+ self.waiter_args,
+ self.failure_message,
+ self.status_message,
+ self.status_queries,
+ )
+ yield TriggerEvent({"status": "success", self.return_key:
self.return_value})
+ except asyncio.CancelledError:
+ try:
+ if self.job_id and self.cancel_on_kill and await
self.safe_to_cancel():
+ self.log.info(
+ "Task was cancelled. Cancelling EMR container job. "
+ "Virtual Cluster ID: %s, Job ID: %s",
+ self.virtual_cluster_id,
+ self.job_id,
+ )
+ try:
+ hook.stop_query(self.job_id)
+ self.log.info("EMR container job %s cancelled
successfully.", self.job_id)
+ except Exception:
+ self.log.exception(
+ "Failed to cancel EMR container job %s. The job
may still be running.",
+ self.job_id,
+ )
+ else:
+ self.log.info(
+ "Trigger may have shutdown or cancel_on_kill is
disabled. "
+ "Skipping job cancellation. Virtual Cluster ID: %s,
Job ID: %s",
+ self.virtual_cluster_id,
+ self.job_id,
+ )
+ except asyncio.CancelledError:
+ raise
+ except Exception:
+ self.log.exception(
+ "Error during cancellation check for EMR container job %s.
The job may still be running.",
+ self.job_id,
+ )
+ raise
+ except AirflowException as e:
+ yield TriggerEvent({"status": "error", "message": str(e),
self.return_key: self.return_value})
Review Comment:
Good catch. Added an `except Exception` fallback after the `AirflowException`
handler to match the base class behavior; both handlers now yield an error `TriggerEvent`.
##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,10 +197,132 @@ def __init__(
waiter_max_attempts=waiter_max_attempts,
aws_conn_id=aws_conn_id,
)
+ self.virtual_cluster_id = virtual_cluster_id
+ self.job_id = job_id
+ self.cancel_on_kill = cancel_on_kill
def hook(self) -> AwsGenericHook:
return EmrContainerHook(aws_conn_id=self.aws_conn_id)
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """Get the task instance for the current trigger (Airflow 2.x
compatibility)."""
+ from sqlalchemy import select
+
+ query = select(TaskInstance).where(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ task_instance = session.scalars(query).one_or_none()
+ if task_instance is None:
+ raise ValueError(
+ f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+ f"task_id: {self.task_instance.task_id}, "
+ f"run_id: {self.task_instance.run_id} and "
+ f"map_index: {self.task_instance.map_index} is not found"
+ )
+ return task_instance
+
+ async def get_task_state(self):
+ """Get the current state of the task instance (Airflow 3.x)."""
+ from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ task_state =
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except (KeyError, TypeError) as e:
+ raise ValueError(
+ f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+ f"task_id: {self.task_instance.task_id}, "
+ f"run_id: {self.task_instance.run_id} and "
+ f"map_index: {self.task_instance.map_index} is not found"
+ ) from e
+ return task_state
+
+ async def safe_to_cancel(self) -> bool:
+ """
+ Whether it is safe to cancel the EMR container job.
+
+ Returns True if task is NOT DEFERRED (user-initiated cancellation).
+ Returns False if task is DEFERRED (triggerer restart - don't cancel
job).
+ """
+ if AIRFLOW_V_3_0_PLUS:
+ task_state = await self.get_task_state()
+ else:
+ task_instance = self.get_task_instance() # type: ignore[call-arg]
Review Comment:
Fixed. `get_task_instance()` is now wrapped with `sync_to_async()`, and
`stop_query()` is now called via `await sync_to_async(hook.stop_query)(self.job_id)`,
so neither synchronous call blocks the triggerer event loop.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]