Copilot commented on code in PR #62240: URL: https://github.com/apache/airflow/pull/62240#discussion_r3066479034
########## providers/amazon/src/airflow/providers/amazon/aws/operators/sagemaker_unified_studio_notebook.py: ########## @@ -0,0 +1,167 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains the Amazon SageMaker Unified Studio Notebook operator. + +This operator supports asynchronous notebook execution in SageMaker Unified +Studio. +""" + +from __future__ import annotations + +from typing import TYPE_CHECKING, Any + +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookHook, +) +from airflow.providers.amazon.aws.links.sagemaker_unified_studio import ( + SageMakerUnifiedStudioLink, +) +from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator +from airflow.providers.amazon.aws.triggers.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookTrigger, +) +from airflow.providers.amazon.aws.utils import validate_execute_complete_event +from airflow.providers.common.compat.sdk import conf + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class SageMakerUnifiedStudioNotebookOperator(AwsBaseOperator[SageMakerUnifiedStudioNotebookHook]): + """ + Execute a notebook in SageMaker Unified Studio. 
+ + This operator calls the DataZone StartNotebookRun API to kick off + headless notebook execution. When not configured otherwise, polls + the GetNotebookRun API until the run reaches a terminal state. + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.operators.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookOperator, + ) + + notebook_operator = SageMakerUnifiedStudioNotebookOperator( + task_id="run_notebook", + notebook_identifier="nb-1234567890", + domain_identifier="dzd_example", + owning_project_identifier="proj_example", + notebook_parameters={"param1": "value1"}, + compute_configuration={"instance_type": "ml.m5.large"}, + timeout_configuration={"run_timeout_in_minutes": 1440}, + ) + + :param task_id: A unique, meaningful id for the task. + :param notebook_identifier: The ID of the notebook to execute. + :param domain_identifier: The ID of the SageMaker Unified Studio domain containing the notebook. + :param owning_project_identifier: The ID of the SageMaker Unified Studio project containing the notebook. + :param client_token: Optional idempotency token. Auto-generated if not provided. + :param notebook_parameters: Optional dict of parameters to pass to the notebook. + :param compute_configuration: Optional compute config. + Example: {"instance_type": "ml.m5.large"} + :param timeout_configuration: Optional timeout settings. + Example: {"run_timeout_in_minutes": 1440} + :param wait_for_completion: If True, wait for the notebook run to finish before + completing the task. If False, the operator returns immediately after starting + the run. (default: True) + :param waiter_delay: Interval in seconds to poll the notebook run status (default: 10). + :param deferrable: If True, the operator will defer polling to the trigger, + freeing up the worker slot while waiting. (default: False) + + .. 
seealso:: + For more information on how to use this operator, take a look at the guide: + :ref:`howto/operator:SageMakerUnifiedStudioNotebookOperator` + """ + + operator_extra_links = (SageMakerUnifiedStudioLink(),) + aws_hook_class = SageMakerUnifiedStudioNotebookHook + + def __init__( + self, + *, + notebook_identifier: str, + domain_identifier: str, + owning_project_identifier: str, + client_token: str | None = None, + notebook_parameters: dict | None = None, + compute_configuration: dict | None = None, + timeout_configuration: dict | None = None, + wait_for_completion: bool = True, + waiter_delay: int = 10, + deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False), + **kwargs, + ): + super().__init__(**kwargs) + self.notebook_identifier = notebook_identifier + self.domain_identifier = domain_identifier + self.owning_project_identifier = owning_project_identifier + self.client_token = client_token + self.notebook_parameters = notebook_parameters + self.compute_configuration = compute_configuration + self.timeout_configuration = timeout_configuration + self.wait_for_completion = wait_for_completion + self.waiter_delay = waiter_delay + self.deferrable = deferrable + + def execute(self, context: Context): + workflow_name = context["dag"].dag_id # Workflow name is the same as the dag_id + response = self.hook.start_notebook_run( + notebook_identifier=self.notebook_identifier, + domain_identifier=self.domain_identifier, + owning_project_identifier=self.owning_project_identifier, + client_token=self.client_token, + notebook_parameters=self.notebook_parameters, + compute_configuration=self.compute_configuration, + timeout_configuration=self.timeout_configuration, + workflow_name=workflow_name, + ) + notebook_run_id = response["notebook_run_id"] Review Comment: The operator expects `start_notebook_run()` to return a `notebook_run_id` key, but the hook currently returns the raw boto response (tests show `{"notebookRunId": ...}`). 
This will raise a `KeyError` at runtime unless the hook normalizes the response (recommended) or the operator reads the API’s actual key. ```suggestion notebook_run_id = response.get("notebook_run_id") or response.get("notebookRunId") if notebook_run_id is None: raise KeyError("start_notebook_run response did not contain 'notebook_run_id' or 'notebookRunId'") ``` ########## providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py: ########## @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook Run hook.""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + +TWELVE_HOURS_IN_MINUTES = 12 * 60 + + +class SageMakerUnifiedStudioNotebookHook(AwsBaseHook): + """ + Interact with Sagemaker Unified Studio Workflows for asynchronous notebook execution. + + This hook provides a wrapper around the DataZone StartNotebookRun / GetNotebookRun APIs. + + Examples: + .. 
code-block:: python + + from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookHook, + ) + + hook = SageMakerUnifiedStudioNotebookHook(aws_conn_id="my_aws_conn") + + Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args: Any, **kwargs: Any): + kwargs.setdefault("client_type", "datazone") + super().__init__(*args, **kwargs) + + def _validate_api_availability(self) -> None: + """ + Verify that the NotebookRun APIs are available in the installed boto3/botocore version. + + :raises RuntimeError: If the required APIs are not available. + """ + required_methods = ("start_notebook_run", "get_notebook_run") + for method_name in required_methods: + if not hasattr(self.conn, method_name): + raise RuntimeError( + f"The '{method_name}' API is not available in the installed boto3/botocore version. " + "Please upgrade boto3/botocore to a version that supports the DataZone " + "NotebookRun APIs." + ) Review Comment: `_validate_api_availability()` is never called, so the intended fast-fail behavior won’t trigger when an older boto3/botocore is installed. Call this method from `__init__`, `start_notebook_run()`, and/or `get_notebook_run()` (single call per hook instance is typically enough). ########## providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py: ########## @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook Run hook.""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + +TWELVE_HOURS_IN_MINUTES = 12 * 60 + + +class SageMakerUnifiedStudioNotebookHook(AwsBaseHook): + """ + Interact with Sagemaker Unified Studio Workflows for asynchronous notebook execution. + + This hook provides a wrapper around the DataZone StartNotebookRun / GetNotebookRun APIs. + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookHook, + ) + + hook = SageMakerUnifiedStudioNotebookHook(aws_conn_id="my_aws_conn") + + Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args: Any, **kwargs: Any): + kwargs.setdefault("client_type", "datazone") + super().__init__(*args, **kwargs) + + def _validate_api_availability(self) -> None: + """ + Verify that the NotebookRun APIs are available in the installed boto3/botocore version. + + :raises RuntimeError: If the required APIs are not available. 
+ """ + required_methods = ("start_notebook_run", "get_notebook_run") + for method_name in required_methods: + if not hasattr(self.conn, method_name): + raise RuntimeError( + f"The '{method_name}' API is not available in the installed boto3/botocore version. " + "Please upgrade boto3/botocore to a version that supports the DataZone " + "NotebookRun APIs." + ) + + def start_notebook_run( + self, + notebook_identifier: str, + domain_identifier: str, + owning_project_identifier: str, + client_token: str | None = None, + notebook_parameters: dict | None = None, + compute_configuration: dict | None = None, + timeout_configuration: dict | None = None, + workflow_name: str | None = None, + ) -> dict: + """ + Start an asynchronous notebook run via the DataZone StartNotebookRun API. + + :param notebook_identifier: The ID of the notebook to execute. + :param domain_identifier: The ID of the DataZone domain containing the notebook. + :param owning_project_identifier: The ID of the DataZone project containing the notebook. + :param client_token: Idempotency token. Auto-generated if not provided. + :param notebook_parameters: Parameters to pass to the notebook. + :param compute_configuration: Compute config (e.g. instance_type). + :param timeout_configuration: Timeout settings (run_timeout_in_minutes). + :param workflow_name: Name of the workflow (DAG) that triggered this run. + :return: The StartNotebookRun API response dict. 
+ """ + params: dict = { + "domain_identifier": domain_identifier, + "owning_project_identifier": owning_project_identifier, + "notebook_identifier": notebook_identifier, + "client_token": client_token or str(uuid.uuid4()), + } + + if notebook_parameters: + params["parameters"] = {"notebook_parameters": notebook_parameters} + if compute_configuration: + params["compute_configuration"] = compute_configuration + if timeout_configuration: + params["timeout_configuration"] = timeout_configuration + if workflow_name: + params["trigger_source"] = {"type": "workflow", "workflow_name": workflow_name} + + self.log.info( + "Starting notebook run for notebook %s in domain %s", notebook_identifier, domain_identifier + ) + return self.conn.start_notebook_run(**params) + + def get_notebook_run(self, notebook_run_id: str, domain_identifier: str) -> dict: + """ + Get the status of a notebook run via the DataZone GetNotebookRun API. + + :param notebook_run_id: The ID of the notebook run. + :param domain_identifier: The ID of the DataZone domain. + :return: The GetNotebookRun API response dict. + """ + return self.conn.get_notebook_run( + domain_identifier=domain_identifier, + identifier=notebook_run_id, + ) + + def wait_for_notebook_run( + self, + notebook_run_id: str, + domain_identifier: str, + waiter_delay: int = 10, + timeout_configuration: dict | None = None, + ) -> dict: + """ + Poll GetNotebookRun until the run reaches a terminal state. + + :param notebook_run_id: The ID of the notebook run to monitor. + :param domain_identifier: The ID of the DataZone domain. + :param waiter_delay: Interval in seconds to poll the notebook run status. + :param timeout_configuration: Timeout settings for the notebook execution. + When provided, the maximum number of poll attempts is derived from + ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours. + :return: A dict with Status and NotebookRunId on success. + :raises RuntimeError: If the run fails or times out. 
+ """ + run_timeout = (timeout_configuration or {}).get("run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES) + waiter_max_attempts = int(run_timeout * 60 / waiter_delay) + + for _attempt in range(waiter_max_attempts): + time.sleep(waiter_delay) + response = self.get_notebook_run(notebook_run_id, domain_identifier=domain_identifier) + status = response.get("status", "") + error_message = response.get("errorMessage", "") + + ret = self._handle_state(notebook_run_id, status, error_message, waiter_delay) + if ret: + return ret + + error_message = "Execution timed out" + self.log.error("Notebook run %s failed with error: %s", notebook_run_id, error_message) + raise RuntimeError(error_message) + + def _handle_state( + self, notebook_run_id: str, state: str, error_message: str, waiter_delay: int = 10 + ) -> dict | None: + """ + Evaluate the current notebook run state and return or raise accordingly. + + :param notebook_run_id: The ID of the notebook run. + :param state: The current state string. + :param error_message: Error message from the API response, if any. + :param waiter_delay: Interval in seconds between polls (for logging). + :return: A dict with Status and NotebookRunId on success, None if still in progress. + :raises RuntimeError: If the run has failed. + """ + in_progress_states = {"QUEUED", "STARTING", "RUNNING", "STOPPING"} + finished_states = {"SUCCEEDED", "STOPPED"} + failure_states = {"FAILED"} Review Comment: `STOPPED` is treated as a success here, but the new botocore waiter marks `STOPPED` as `failure` (and the sensor raises for STOPPED). This creates inconsistent outcomes: non-deferrable execution may succeed while deferrable execution fails for the same terminal state. Align STOPPED handling across hook/sensor/waiter (most consistent is to treat STOPPED as a failure/terminal-non-success everywhere). 
```suggestion finished_states = {"SUCCEEDED"} failure_states = {"FAILED", "STOPPED"} ``` ########## providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py: ########## @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook Run hook.""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + +TWELVE_HOURS_IN_MINUTES = 12 * 60 + + +class SageMakerUnifiedStudioNotebookHook(AwsBaseHook): + """ + Interact with Sagemaker Unified Studio Workflows for asynchronous notebook execution. + + This hook provides a wrapper around the DataZone StartNotebookRun / GetNotebookRun APIs. + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookHook, + ) + + hook = SageMakerUnifiedStudioNotebookHook(aws_conn_id="my_aws_conn") + + Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. 
seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args: Any, **kwargs: Any): + kwargs.setdefault("client_type", "datazone") + super().__init__(*args, **kwargs) + + def _validate_api_availability(self) -> None: + """ + Verify that the NotebookRun APIs are available in the installed boto3/botocore version. + + :raises RuntimeError: If the required APIs are not available. + """ + required_methods = ("start_notebook_run", "get_notebook_run") + for method_name in required_methods: + if not hasattr(self.conn, method_name): + raise RuntimeError( + f"The '{method_name}' API is not available in the installed boto3/botocore version. " + "Please upgrade boto3/botocore to a version that supports the DataZone " + "NotebookRun APIs." + ) + + def start_notebook_run( + self, + notebook_identifier: str, + domain_identifier: str, + owning_project_identifier: str, + client_token: str | None = None, + notebook_parameters: dict | None = None, + compute_configuration: dict | None = None, + timeout_configuration: dict | None = None, + workflow_name: str | None = None, + ) -> dict: + """ + Start an asynchronous notebook run via the DataZone StartNotebookRun API. + + :param notebook_identifier: The ID of the notebook to execute. + :param domain_identifier: The ID of the DataZone domain containing the notebook. + :param owning_project_identifier: The ID of the DataZone project containing the notebook. + :param client_token: Idempotency token. Auto-generated if not provided. + :param notebook_parameters: Parameters to pass to the notebook. + :param compute_configuration: Compute config (e.g. instance_type). + :param timeout_configuration: Timeout settings (run_timeout_in_minutes). + :param workflow_name: Name of the workflow (DAG) that triggered this run. + :return: The StartNotebookRun API response dict. 
+ """ + params: dict = { + "domain_identifier": domain_identifier, + "owning_project_identifier": owning_project_identifier, + "notebook_identifier": notebook_identifier, + "client_token": client_token or str(uuid.uuid4()), + } + + if notebook_parameters: + params["parameters"] = {"notebook_parameters": notebook_parameters} + if compute_configuration: + params["compute_configuration"] = compute_configuration + if timeout_configuration: + params["timeout_configuration"] = timeout_configuration + if workflow_name: + params["trigger_source"] = {"type": "workflow", "workflow_name": workflow_name} + + self.log.info( + "Starting notebook run for notebook %s in domain %s", notebook_identifier, domain_identifier + ) + return self.conn.start_notebook_run(**params) + + def get_notebook_run(self, notebook_run_id: str, domain_identifier: str) -> dict: + """ + Get the status of a notebook run via the DataZone GetNotebookRun API. + + :param notebook_run_id: The ID of the notebook run. + :param domain_identifier: The ID of the DataZone domain. + :return: The GetNotebookRun API response dict. + """ + return self.conn.get_notebook_run( + domain_identifier=domain_identifier, + identifier=notebook_run_id, + ) + + def wait_for_notebook_run( + self, + notebook_run_id: str, + domain_identifier: str, + waiter_delay: int = 10, + timeout_configuration: dict | None = None, + ) -> dict: + """ + Poll GetNotebookRun until the run reaches a terminal state. + + :param notebook_run_id: The ID of the notebook run to monitor. + :param domain_identifier: The ID of the DataZone domain. + :param waiter_delay: Interval in seconds to poll the notebook run status. + :param timeout_configuration: Timeout settings for the notebook execution. + When provided, the maximum number of poll attempts is derived from + ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours. + :return: A dict with Status and NotebookRunId on success. + :raises RuntimeError: If the run fails or times out. 
+ """ + run_timeout = (timeout_configuration or {}).get("run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES) + waiter_max_attempts = int(run_timeout * 60 / waiter_delay) Review Comment: Same issue as the trigger: `waiter_delay == 0` will crash, and small timeouts can yield `waiter_max_attempts == 0` (leading to an immediate timeout without polling). Validate `waiter_delay > 0` and compute attempts via `ceil(...)` with a minimum of 1. ########## providers/amazon/src/airflow/providers/amazon/aws/sensors/sagemaker_unified_studio_notebook.py: ########## @@ -0,0 +1,102 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +This module contains the Amazon SageMaker Unified Studio Notebook Run sensor. + +This sensor polls the DataZone GetNotebookRun API until the notebook run +reaches a terminal state. 
+""" + +from __future__ import annotations + +from collections.abc import Sequence +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookHook, +) +from airflow.providers.amazon.aws.sensors.base_aws import AwsBaseSensor + +if TYPE_CHECKING: + from airflow.sdk import Context + + +class SageMakerUnifiedStudioNotebookSensor(AwsBaseSensor[SageMakerUnifiedStudioNotebookHook]): + """ + Polls a SageMaker Unified Studio notebook execution until it reaches a terminal state. + + 'SUCCEEDED', 'FAILED', 'STOPPED' + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.sensors.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookSensor, + ) + + notebook_sensor = SageMakerUnifiedStudioNotebookSensor( + task_id="wait_for_notebook", + domain_identifier="dzd_example", + owning_project_identifier="proj_example", + notebook_run_id="nr-1234567890", + ) + + :param domain_identifier: The ID of the SageMaker Unified Studio domain containing the notebook. + :param owning_project_identifier: The ID of the SageMaker Unified Studio project containing the notebook. + :param notebook_run_id: The ID of the notebook run to monitor. + This is returned by the ``SageMakerUnifiedStudioNotebookOperator``. 
+ """ + + aws_hook_class = SageMakerUnifiedStudioNotebookHook + template_fields: Sequence[str] = AwsBaseSensor.template_fields + ("notebook_run_id",) + + def __init__( + self, + *, + domain_identifier: str, + owning_project_identifier: str, + notebook_run_id: str, + **kwargs, + ): + super().__init__(**kwargs) + self.domain_identifier = domain_identifier + self.owning_project_identifier = owning_project_identifier + self.notebook_run_id = notebook_run_id + self.success_states = ["SUCCEEDED"] + self.in_progress_states = ["QUEUED", "STARTING", "RUNNING", "STOPPING"] + + # override from base sensor + def poke(self, context: Context) -> bool: + response = self.hook.get_notebook_run(self.notebook_run_id, domain_identifier=self.domain_identifier) + status = response.get("status", "") + + if status in self.success_states: + self.log.info("Exiting notebook run %s. State: %s", self.notebook_run_id, status) + return True + + if status in self.in_progress_states: + return False + + error_message = f"Exiting notebook run {self.notebook_run_id}. State: {status}" + self.log.info(error_message) + raise RuntimeError(error_message) Review Comment: Sensors/operators typically raise `AirflowException` (or another Airflow-specific exception) to signal task failure consistently. Consider raising `AirflowException` here (and in related paths) instead of `RuntimeError` so failures are categorized/handled uniformly by Airflow. ########## providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py: ########## @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook Run hook.""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + +TWELVE_HOURS_IN_MINUTES = 12 * 60 + + +class SageMakerUnifiedStudioNotebookHook(AwsBaseHook): + """ + Interact with Sagemaker Unified Studio Workflows for asynchronous notebook execution. + + This hook provides a wrapper around the DataZone StartNotebookRun / GetNotebookRun APIs. + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookHook, + ) + + hook = SageMakerUnifiedStudioNotebookHook(aws_conn_id="my_aws_conn") + + Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args: Any, **kwargs: Any): + kwargs.setdefault("client_type", "datazone") + super().__init__(*args, **kwargs) + + def _validate_api_availability(self) -> None: + """ + Verify that the NotebookRun APIs are available in the installed boto3/botocore version. + + :raises RuntimeError: If the required APIs are not available. 
+ """ + required_methods = ("start_notebook_run", "get_notebook_run") + for method_name in required_methods: + if not hasattr(self.conn, method_name): + raise RuntimeError( + f"The '{method_name}' API is not available in the installed boto3/botocore version. " + "Please upgrade boto3/botocore to a version that supports the DataZone " + "NotebookRun APIs." + ) + + def start_notebook_run( + self, + notebook_identifier: str, + domain_identifier: str, + owning_project_identifier: str, + client_token: str | None = None, + notebook_parameters: dict | None = None, + compute_configuration: dict | None = None, + timeout_configuration: dict | None = None, + workflow_name: str | None = None, + ) -> dict: + """ + Start an asynchronous notebook run via the DataZone StartNotebookRun API. + + :param notebook_identifier: The ID of the notebook to execute. + :param domain_identifier: The ID of the DataZone domain containing the notebook. + :param owning_project_identifier: The ID of the DataZone project containing the notebook. + :param client_token: Idempotency token. Auto-generated if not provided. + :param notebook_parameters: Parameters to pass to the notebook. + :param compute_configuration: Compute config (e.g. instance_type). + :param timeout_configuration: Timeout settings (run_timeout_in_minutes). + :param workflow_name: Name of the workflow (DAG) that triggered this run. + :return: The StartNotebookRun API response dict. 
+ """ + params: dict = { + "domain_identifier": domain_identifier, + "owning_project_identifier": owning_project_identifier, + "notebook_identifier": notebook_identifier, + "client_token": client_token or str(uuid.uuid4()), + } + + if notebook_parameters: + params["parameters"] = {"notebook_parameters": notebook_parameters} + if compute_configuration: + params["compute_configuration"] = compute_configuration + if timeout_configuration: + params["timeout_configuration"] = timeout_configuration + if workflow_name: + params["trigger_source"] = {"type": "workflow", "workflow_name": workflow_name} + + self.log.info( + "Starting notebook run for notebook %s in domain %s", notebook_identifier, domain_identifier + ) + return self.conn.start_notebook_run(**params) + + def get_notebook_run(self, notebook_run_id: str, domain_identifier: str) -> dict: + """ + Get the status of a notebook run via the DataZone GetNotebookRun API. + + :param notebook_run_id: The ID of the notebook run. + :param domain_identifier: The ID of the DataZone domain. + :return: The GetNotebookRun API response dict. + """ + return self.conn.get_notebook_run( + domain_identifier=domain_identifier, + identifier=notebook_run_id, + ) + + def wait_for_notebook_run( + self, + notebook_run_id: str, + domain_identifier: str, + waiter_delay: int = 10, + timeout_configuration: dict | None = None, + ) -> dict: + """ + Poll GetNotebookRun until the run reaches a terminal state. + + :param notebook_run_id: The ID of the notebook run to monitor. + :param domain_identifier: The ID of the DataZone domain. + :param waiter_delay: Interval in seconds to poll the notebook run status. + :param timeout_configuration: Timeout settings for the notebook execution. + When provided, the maximum number of poll attempts is derived from + ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours. + :return: A dict with Status and NotebookRunId on success. + :raises RuntimeError: If the run fails or times out. 
+ """ + run_timeout = (timeout_configuration or {}).get("run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES) + waiter_max_attempts = int(run_timeout * 60 / waiter_delay) + + for _attempt in range(waiter_max_attempts): + time.sleep(waiter_delay) Review Comment: The polling loop sleeps before the first `get_notebook_run()` call, adding an avoidable initial delay even when the run is already in a terminal state. Consider polling immediately on the first iteration, then sleeping between subsequent polls (this also aligns better with typical waiter semantics). ```suggestion if _attempt > 0: time.sleep(waiter_delay) ``` ########## providers/amazon/src/airflow/providers/amazon/aws/hooks/sagemaker_unified_studio_notebook.py: ########## @@ -0,0 +1,206 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""This module contains the Amazon SageMaker Unified Studio Notebook Run hook.""" + +from __future__ import annotations + +import time +import uuid +from typing import Any + +from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook + +TWELVE_HOURS_IN_MINUTES = 12 * 60 + + +class SageMakerUnifiedStudioNotebookHook(AwsBaseHook): + """ + Interact with Sagemaker Unified Studio Workflows for asynchronous notebook execution. 
+ + This hook provides a wrapper around the DataZone StartNotebookRun / GetNotebookRun APIs. + + Examples: + .. code-block:: python + + from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookHook, + ) + + hook = SageMakerUnifiedStudioNotebookHook(aws_conn_id="my_aws_conn") + + Additional arguments (such as ``aws_conn_id`` or ``region_name``) may be specified and + are passed down to the underlying AwsBaseHook. + + .. seealso:: + - :class:`airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook` + """ + + def __init__(self, *args: Any, **kwargs: Any): + kwargs.setdefault("client_type", "datazone") + super().__init__(*args, **kwargs) + + def _validate_api_availability(self) -> None: + """ + Verify that the NotebookRun APIs are available in the installed boto3/botocore version. + + :raises RuntimeError: If the required APIs are not available. + """ + required_methods = ("start_notebook_run", "get_notebook_run") + for method_name in required_methods: + if not hasattr(self.conn, method_name): + raise RuntimeError( + f"The '{method_name}' API is not available in the installed boto3/botocore version. " + "Please upgrade boto3/botocore to a version that supports the DataZone " + "NotebookRun APIs." + ) + + def start_notebook_run( + self, + notebook_identifier: str, + domain_identifier: str, + owning_project_identifier: str, + client_token: str | None = None, + notebook_parameters: dict | None = None, + compute_configuration: dict | None = None, + timeout_configuration: dict | None = None, + workflow_name: str | None = None, + ) -> dict: + """ + Start an asynchronous notebook run via the DataZone StartNotebookRun API. + + :param notebook_identifier: The ID of the notebook to execute. + :param domain_identifier: The ID of the DataZone domain containing the notebook. + :param owning_project_identifier: The ID of the DataZone project containing the notebook. + :param client_token: Idempotency token. 
Auto-generated if not provided. + :param notebook_parameters: Parameters to pass to the notebook. + :param compute_configuration: Compute config (e.g. instance_type). + :param timeout_configuration: Timeout settings (run_timeout_in_minutes). + :param workflow_name: Name of the workflow (DAG) that triggered this run. + :return: The StartNotebookRun API response dict. + """ + params: dict = { + "domain_identifier": domain_identifier, + "owning_project_identifier": owning_project_identifier, + "notebook_identifier": notebook_identifier, + "client_token": client_token or str(uuid.uuid4()), + } + + if notebook_parameters: + params["parameters"] = {"notebook_parameters": notebook_parameters} + if compute_configuration: + params["compute_configuration"] = compute_configuration + if timeout_configuration: + params["timeout_configuration"] = timeout_configuration + if workflow_name: + params["trigger_source"] = {"type": "workflow", "workflow_name": workflow_name} + + self.log.info( + "Starting notebook run for notebook %s in domain %s", notebook_identifier, domain_identifier + ) + return self.conn.start_notebook_run(**params) + + def get_notebook_run(self, notebook_run_id: str, domain_identifier: str) -> dict: + """ + Get the status of a notebook run via the DataZone GetNotebookRun API. + + :param notebook_run_id: The ID of the notebook run. + :param domain_identifier: The ID of the DataZone domain. + :return: The GetNotebookRun API response dict. + """ + return self.conn.get_notebook_run( + domain_identifier=domain_identifier, + identifier=notebook_run_id, + ) + + def wait_for_notebook_run( + self, + notebook_run_id: str, + domain_identifier: str, + waiter_delay: int = 10, + timeout_configuration: dict | None = None, + ) -> dict: + """ + Poll GetNotebookRun until the run reaches a terminal state. + + :param notebook_run_id: The ID of the notebook run to monitor. + :param domain_identifier: The ID of the DataZone domain. 
+ :param waiter_delay: Interval in seconds to poll the notebook run status. + :param timeout_configuration: Timeout settings for the notebook execution. + When provided, the maximum number of poll attempts is derived from + ``run_timeout_in_minutes * 60 / waiter_delay``. Defaults to 12 hours. + :return: A dict with Status and NotebookRunId on success. Review Comment: The documented return shape doesn’t match the implementation: `_handle_state()`/`wait_for_notebook_run()` return `{"State": ..., "NotebookRunId": ...}` (not `Status`). Update the docstring to reflect the actual keys or adjust the returned dict to match the documented contract. ```suggestion :return: A dict with State and NotebookRunId on success. ``` ########## providers/amazon/src/airflow/providers/amazon/aws/triggers/sagemaker_unified_studio_notebook.py: ########## @@ -0,0 +1,101 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Trigger for monitoring SageMaker Unified Studio Notebook runs asynchronously.""" + +from __future__ import annotations + +from typing import TYPE_CHECKING + +from airflow.providers.amazon.aws.hooks.sagemaker_unified_studio_notebook import ( + SageMakerUnifiedStudioNotebookHook, +) +from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger + +if TYPE_CHECKING: + from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook + +TWELVE_HOURS_IN_MINUTES = 12 * 60 + + +class SageMakerUnifiedStudioNotebookTrigger(AwsBaseWaiterTrigger): + """ + Watches an asynchronous notebook run, triggering when it reaches a terminal state. + + Uses a custom boto waiter (``notebook_run_complete``) defined in + ``waiters/datazone.json`` to poll the DataZone ``GetNotebookRun`` API. + + :param notebook_run_id: The ID of the notebook run to monitor. + :param domain_identifier: The ID of the DataZone domain. + :param owning_project_identifier: The ID of the DataZone project. + :param waiter_delay: Interval in seconds between polls (default: 10). + :param waiter_max_attempts: Maximum number of poll attempts. + :param aws_conn_id: The Airflow connection used for AWS credentials. + :param timeout_configuration: Optional timeout settings. When provided, the maximum + number of poll attempts is derived from ``run_timeout_in_minutes * 60 / waiter_delay``. + Defaults to a 12-hour timeout when omitted. 
+ Example: {"run_timeout_in_minutes": 720} + """ + + def __init__( + self, + notebook_run_id: str, + domain_identifier: str, + owning_project_identifier: str, + waiter_delay: int = 10, + timeout_configuration: dict | None = None, + aws_conn_id: str | None = None, + **kwargs, + ): + run_timeout = (timeout_configuration or {}).get("run_timeout_in_minutes", TWELVE_HOURS_IN_MINUTES) + waiter_max_attempts = int(run_timeout * 60 / waiter_delay) Review Comment: This calculation can produce `0` attempts when `run_timeout * 60 < waiter_delay` and will raise `ZeroDivisionError` when `waiter_delay == 0`. Consider validating `waiter_delay > 0` and computing attempts using `ceil(...)` with a minimum of 1 to avoid immediate/undefined behavior on small timeouts. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
