Copilot commented on code in PR #64770:
URL: https://github.com/apache/airflow/pull/64770#discussion_r3067509583


##########
providers/amazon/src/airflow/providers/amazon/aws/triggers/emr.py:
##########
@@ -190,9 +198,129 @@ def __init__(
             waiter_max_attempts=waiter_max_attempts,
             aws_conn_id=aws_conn_id,
         )
+        self.virtual_cluster_id = virtual_cluster_id
+        self.job_id = job_id
+        self.cancel_on_kill = cancel_on_kill
 
     def hook(self) -> AwsGenericHook:
-        return EmrContainerHook(aws_conn_id=self.aws_conn_id)
+        return EmrContainerHook(aws_conn_id=self.aws_conn_id, 
virtual_cluster_id=self.virtual_cluster_id)
+
+    if not AIRFLOW_V_3_0_PLUS:
+
+        @provide_session
+        def get_task_instance(self, session: Session) -> TaskInstance:
+            """Get the task instance for the current trigger (Airflow 2.x 
compatibility)."""
+            query = select(TaskInstance).where(
+                TaskInstance.dag_id == self.task_instance.dag_id,
+                TaskInstance.task_id == self.task_instance.task_id,
+                TaskInstance.run_id == self.task_instance.run_id,
+                TaskInstance.map_index == self.task_instance.map_index,
+            )
+            task_instance = session.scalars(query).one_or_none()
+            if task_instance is None:
+                raise ValueError(
+                    f"TaskInstance with dag_id: {self.task_instance.dag_id}, "
+                    f"task_id: {self.task_instance.task_id}, "
+                    f"run_id: {self.task_instance.run_id} and "
+                    f"map_index: {self.task_instance.map_index} is not found"
+                )
+            return task_instance
+
+    async def get_task_state(self):
+        """Get the current state of the task instance (Airflow 3.x)."""
+        from airflow.sdk.execution_time.task_runner import RuntimeTaskInstance
+
+        task_states_response = await 
sync_to_async(RuntimeTaskInstance.get_task_states)(
+            dag_id=self.task_instance.dag_id,
+            task_ids=[self.task_instance.task_id],
+            run_ids=[self.task_instance.run_id],
+            map_index=self.task_instance.map_index,
+        )
+        try:
+            task_state = 
task_states_response[self.task_instance.run_id][self.task_instance.task_id]

Review Comment:
   In Airflow 3.x, `RuntimeTaskInstance.get_task_states(..., map_index=...)` 
returns a mapping where mapped task instances use keys like 
`"{task_id}_{map_index}"` (see execution API 
`task_instances.get_task_instance_states`). This code always indexes with just 
`task_id`, so mapped tasks will raise `KeyError` and be treated as "not found", 
which prevents cancellation for mapped deferrable EMR container jobs. Update 
the lookup key to include `map_index` when `self.task_instance.map_index >= 0` 
(or adjust the API call/response handling accordingly).
   ```suggestion
           task_instance_key = self.task_instance.task_id
           if self.task_instance.map_index >= 0:
               task_instance_key = f"{self.task_instance.task_id}_{self.task_instance.map_index}"
           try:
               task_state = task_states_response[self.task_instance.run_id][task_instance_key]
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to