This is an automated email from the ASF dual-hosted git repository.
eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 004d1d3a84 Adding Amazon Neptune Hook and Operators (#37000)
004d1d3a84 is described below
commit 004d1d3a84224ac728f0bcfee68dae13522fe907
Author: ellisms <[email protected]>
AuthorDate: Fri Jan 26 16:11:04 2024 -0500
Adding Amazon Neptune Hook and Operators (#37000)
* neptune operators
* System test fix
* Update return type
Co-authored-by: Wei Lee <[email protected]>
* Update airflow/providers/amazon/aws/operators/neptune.py
Co-authored-by: Wei Lee <[email protected]>
* Update airflow/providers/amazon/aws/operators/neptune.py
Co-authored-by: Wei Lee <[email protected]>
* Update airflow/providers/amazon/aws/operators/neptune.py
Co-authored-by: Wei Lee <[email protected]>
* PR Review changes
* Review changes; fixed databrew waiter test case
* Moved cluster states to hook
* Update airflow/providers/amazon/aws/operators/neptune.py
Co-authored-by: Elad Kalif <[email protected]>
---------
Co-authored-by: Wei Lee <[email protected]>
Co-authored-by: Elad Kalif <[email protected]>
---
airflow/providers/amazon/aws/hooks/neptune.py | 85 ++++++++
airflow/providers/amazon/aws/operators/neptune.py | 218 +++++++++++++++++++++
airflow/providers/amazon/aws/triggers/neptune.py | 115 +++++++++++
airflow/providers/amazon/aws/waiters/neptune.json | 85 ++++++++
airflow/providers/amazon/provider.yaml | 15 ++
.../operators/neptune.rst | 77 ++++++++
docs/integration-logos/aws/Amazon-Neptune_64.png | Bin 0 -> 19338 bytes
tests/providers/amazon/aws/hooks/test_neptune.py | 52 +++++
.../providers/amazon/aws/operators/test_neptune.py | 152 ++++++++++++++
.../providers/amazon/aws/triggers/test_neptune.py | 82 ++++++++
tests/providers/amazon/aws/waiters/test_neptune.py | 89 +++++++++
.../system/providers/amazon/aws/example_neptune.py | 68 +++++++
12 files changed, 1038 insertions(+)
diff --git a/airflow/providers/amazon/aws/hooks/neptune.py
b/airflow/providers/amazon/aws/hooks/neptune.py
new file mode 100644
index 0000000000..a0640647e3
--- /dev/null
+++ b/airflow/providers/amazon/aws/hooks/neptune.py
@@ -0,0 +1,85 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from airflow.providers.amazon.aws.hooks.base_aws import AwsBaseHook
+
+
class NeptuneHook(AwsBaseHook):
    """
    Interact with Amazon Neptune.

    Additional arguments (such as ``aws_conn_id``) may be specified and
    are passed down to the underlying AwsBaseHook.

    .. seealso::
        - :class:`~airflow.providers.amazon.aws.hooks.base_aws.AwsBaseHook`
    """

    # Cluster states that the start/stop operators treat as terminal.
    AVAILABLE_STATES = ["available"]
    STOPPED_STATES = ["stopped"]

    def __init__(self, *args, **kwargs):
        # Force the boto3 client type so callers always get a Neptune client.
        kwargs["client_type"] = "neptune"
        super().__init__(*args, **kwargs)

    def _wait_for_cluster(self, waiter_name: str, cluster_id: str, delay: int, max_attempts: int) -> str:
        """Run the named custom waiter for *cluster_id* and return the final cluster status."""
        waiter_config = {"Delay": delay, "MaxAttempts": max_attempts}
        self.get_waiter(waiter_name).wait(DBClusterIdentifier=cluster_id, WaiterConfig=waiter_config)
        status = self.get_cluster_status(cluster_id)
        self.log.info("Finished waiting for cluster %s. Status is now %s", cluster_id, status)
        return status

    def wait_for_cluster_availability(self, cluster_id: str, delay: int = 30, max_attempts: int = 60) -> str:
        """
        Wait for Neptune cluster to start.

        :param cluster_id: The ID of the cluster to wait for.
        :param delay: Time in seconds to delay between polls.
        :param max_attempts: Maximum number of attempts to poll for completion.
        :return: The status of the cluster.
        """
        return self._wait_for_cluster("cluster_available", cluster_id, delay, max_attempts)

    def wait_for_cluster_stopped(self, cluster_id: str, delay: int = 30, max_attempts: int = 60) -> str:
        """
        Wait for Neptune cluster to stop.

        :param cluster_id: The ID of the cluster to wait for.
        :param delay: Time in seconds to delay between polls.
        :param max_attempts: Maximum number of attempts to poll for completion.
        :return: The status of the cluster.
        """
        return self._wait_for_cluster("cluster_stopped", cluster_id, delay, max_attempts)

    def get_cluster_status(self, cluster_id: str) -> str:
        """
        Get the status of a Neptune cluster.

        :param cluster_id: The ID of the cluster to get the status of.
        :return: The status of the cluster.
        """
        response = self.get_conn().describe_db_clusters(DBClusterIdentifier=cluster_id)
        return response["DBClusters"][0]["Status"]
diff --git a/airflow/providers/amazon/aws/operators/neptune.py
b/airflow/providers/amazon/aws/operators/neptune.py
new file mode 100644
index 0000000000..a55b40c378
--- /dev/null
+++ b/airflow/providers/amazon/aws/operators/neptune.py
@@ -0,0 +1,218 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Any, Sequence
+
+from airflow.configuration import conf
+from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+from airflow.providers.amazon.aws.operators.base_aws import AwsBaseOperator
+from airflow.providers.amazon.aws.triggers.neptune import (
+ NeptuneClusterAvailableTrigger,
+ NeptuneClusterStoppedTrigger,
+)
+from airflow.providers.amazon.aws.utils.mixins import aws_template_fields
+
+if TYPE_CHECKING:
+ from airflow.utils.context import Context
+
+
class NeptuneStartDbClusterOperator(AwsBaseOperator[NeptuneHook]):
    """Starts an Amazon Neptune DB cluster.

    Amazon Neptune Database is a serverless graph database designed for superior scalability
    and availability. Neptune Database provides built-in security, continuous backups, and
    integrations with other AWS services

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:NeptuneStartDbClusterOperator`

    :param db_cluster_id: The DB cluster identifier of the Neptune DB cluster to be started.
    :param wait_for_completion: Whether to wait for the cluster to start. (default: True)
    :param deferrable: If True, the operator will wait asynchronously for the cluster to start.
        This implies waiting for completion. This mode requires aiobotocore module to be installed.
        (default: False)
    :param waiter_delay: Time in seconds to wait between status checks.
    :param waiter_max_attempts: Maximum number of attempts to check for job completion.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is ``None`` or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    :return: dictionary with Neptune cluster id
    """

    aws_hook_class = NeptuneHook
    template_fields: Sequence[str] = aws_template_fields("cluster_id")

    def __init__(
        self,
        db_cluster_id: str,
        wait_for_completion: bool = True,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cluster_id = db_cluster_id
        self.wait_for_completion = wait_for_completion
        self.deferrable = deferrable
        self.delay = waiter_delay
        self.max_attempts = waiter_max_attempts

    def execute(self, context: Context) -> dict[str, str]:
        self.log.info("Starting Neptune cluster: %s", self.cluster_id)

        # Check to make sure the cluster is not already available.
        status = self.hook.get_cluster_status(self.cluster_id)
        if status.lower() in NeptuneHook.AVAILABLE_STATES:
            self.log.info("Neptune cluster %s is already available.", self.cluster_id)
            return {"db_cluster_id": self.cluster_id}

        resp = self.hook.conn.start_db_cluster(DBClusterIdentifier=self.cluster_id)
        # BUGFIX: start_db_cluster returns a single "DBCluster" mapping, not a
        # "DBClusters" list, so the previous lookup always produced "Unknown".
        status = resp.get("DBCluster", {}).get("Status", "Unknown")
        self.log.info("Start request sent; cluster %s reported status: %s", self.cluster_id, status)

        if self.deferrable:
            self.log.info("Deferring for cluster start: %s", self.cluster_id)

            self.defer(
                trigger=NeptuneClusterAvailableTrigger(
                    aws_conn_id=self.aws_conn_id,
                    db_cluster_id=self.cluster_id,
                    waiter_delay=self.delay,
                    waiter_max_attempts=self.max_attempts,
                ),
                method_name="execute_complete",
            )

        elif self.wait_for_completion:
            self.log.info("Waiting for Neptune cluster %s to start.", self.cluster_id)
            self.hook.wait_for_cluster_availability(self.cluster_id, self.delay, self.max_attempts)

        return {"db_cluster_id": self.cluster_id}

    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
        """Resume after the deferral trigger fires and return the cluster id from the event."""
        status = ""
        cluster_id = ""

        if event:
            status = event.get("status", "")
            cluster_id = event.get("cluster_id", "")

        self.log.info("Neptune cluster %s available with status: %s", cluster_id, status)

        return {"db_cluster_id": cluster_id}
+
+
class NeptuneStopDbClusterOperator(AwsBaseOperator[NeptuneHook]):
    """
    Stops an Amazon Neptune DB cluster.

    Amazon Neptune Database is a serverless graph database designed for superior scalability
    and availability. Neptune Database provides built-in security, continuous backups, and
    integrations with other AWS services

    .. seealso::
        For more information on how to use this operator, take a look at the guide:
        :ref:`howto/operator:StopNeptuneDbClusterOperator`

    :param db_cluster_id: The DB cluster identifier of the Neptune DB cluster to be stopped.
    :param wait_for_completion: Whether to wait for cluster to stop. (default: True)
    :param deferrable: If True, the operator will wait asynchronously for the cluster to stop.
        This implies waiting for completion. This mode requires aiobotocore module to be installed.
        (default: False)
    :param waiter_delay: Time in seconds to wait between status checks.
    :param waiter_max_attempts: Maximum number of attempts to check for job completion.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
        If this is ``None`` or empty then the default boto3 behaviour is used. If
        running Airflow in a distributed manner and aws_conn_id is None or
        empty, then default boto3 configuration would be used (and must be
        maintained on each worker node).
    :param region_name: AWS region_name. If not specified then the default boto3 behaviour is used.
    :param botocore_config: Configuration dictionary (key-values) for botocore client. See:
        https://botocore.amazonaws.com/v1/documentation/api/latest/reference/config.html
    :return: dictionary with Neptune cluster id
    """

    aws_hook_class = NeptuneHook
    template_fields: Sequence[str] = aws_template_fields("cluster_id")

    def __init__(
        self,
        db_cluster_id: str,
        wait_for_completion: bool = True,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        deferrable: bool = conf.getboolean("operators", "default_deferrable", fallback=False),
        **kwargs,
    ):
        super().__init__(**kwargs)
        self.cluster_id = db_cluster_id
        self.wait_for_completion = wait_for_completion
        self.deferrable = deferrable
        self.delay = waiter_delay
        self.max_attempts = waiter_max_attempts

    def execute(self, context: Context) -> dict[str, str]:
        self.log.info("Stopping Neptune cluster: %s", self.cluster_id)

        # Check to make sure the cluster is not already stopped.
        status = self.hook.get_cluster_status(self.cluster_id)
        if status.lower() in NeptuneHook.STOPPED_STATES:
            self.log.info("Neptune cluster %s is already stopped.", self.cluster_id)
            return {"db_cluster_id": self.cluster_id}

        resp = self.hook.conn.stop_db_cluster(DBClusterIdentifier=self.cluster_id)
        # BUGFIX: stop_db_cluster returns a single "DBCluster" mapping, not a
        # "DBClusters" list, so the previous lookup always produced "Unknown".
        status = resp.get("DBCluster", {}).get("Status", "Unknown")
        self.log.info("Stop request sent; cluster %s reported status: %s", self.cluster_id, status)

        if self.deferrable:
            self.log.info("Deferring for cluster stop: %s", self.cluster_id)

            self.defer(
                trigger=NeptuneClusterStoppedTrigger(
                    aws_conn_id=self.aws_conn_id,
                    db_cluster_id=self.cluster_id,
                    waiter_delay=self.delay,
                    waiter_max_attempts=self.max_attempts,
                ),
                method_name="execute_complete",
            )

        elif self.wait_for_completion:
            # BUGFIX: previous message said "to start" in the stop operator.
            self.log.info("Waiting for Neptune cluster %s to stop.", self.cluster_id)
            self.hook.wait_for_cluster_stopped(self.cluster_id, self.delay, self.max_attempts)

        return {"db_cluster_id": self.cluster_id}

    def execute_complete(self, context: Context, event: dict[str, Any] | None = None) -> dict[str, str]:
        """Resume after the deferral trigger fires and return the cluster id from the event."""
        status = ""
        cluster_id = ""

        if event:
            status = event.get("status", "")
            cluster_id = event.get("cluster_id", "")

        self.log.info("Neptune cluster %s stopped with status: %s", cluster_id, status)

        return {"db_cluster_id": cluster_id}
diff --git a/airflow/providers/amazon/aws/triggers/neptune.py
b/airflow/providers/amazon/aws/triggers/neptune.py
new file mode 100644
index 0000000000..4b7d34f542
--- /dev/null
+++ b/airflow/providers/amazon/aws/triggers/neptune.py
@@ -0,0 +1,115 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import TYPE_CHECKING
+
+from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+from airflow.providers.amazon.aws.triggers.base import AwsBaseWaiterTrigger
+
+if TYPE_CHECKING:
+ from airflow.providers.amazon.aws.hooks.base_aws import AwsGenericHook
+
+
class NeptuneClusterAvailableTrigger(AwsBaseWaiterTrigger):
    """
    Triggers when a Neptune Cluster is available.

    :param db_cluster_id: Cluster ID to poll.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: AWS region name (example: us-east-1)
    """

    def __init__(
        self,
        *,
        db_cluster_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str | None = None,
        region_name: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            serialized_fields={"db_cluster_id": db_cluster_id},
            waiter_name="cluster_available",
            waiter_args={"DBClusterIdentifier": db_cluster_id},
            failure_message="Failed to start Neptune cluster",
            status_message="Status of Neptune cluster is",
            status_queries=["DBClusters[0].Status"],
            return_key="db_cluster_id",
            return_value=db_cluster_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
            # BUGFIX: region_name was accepted but never forwarded, so
            # self.region_name (used by hook() below) was silently dropped.
            region_name=region_name,
            **kwargs,
        )

    def hook(self) -> AwsGenericHook:
        """Build the NeptuneHook used by the base trigger's async polling loop."""
        return NeptuneHook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region_name,
            verify=self.verify,
            config=self.botocore_config,
        )
+
+
class NeptuneClusterStoppedTrigger(AwsBaseWaiterTrigger):
    """
    Triggers when a Neptune Cluster is stopped.

    :param db_cluster_id: Cluster ID to poll.
    :param waiter_delay: The amount of time in seconds to wait between attempts.
    :param waiter_max_attempts: The maximum number of attempts to be made.
    :param aws_conn_id: The Airflow connection used for AWS credentials.
    :param region_name: AWS region name (example: us-east-1)
    """

    def __init__(
        self,
        *,
        db_cluster_id: str,
        waiter_delay: int = 30,
        waiter_max_attempts: int = 60,
        aws_conn_id: str | None = None,
        region_name: str | None = None,
        **kwargs,
    ) -> None:
        super().__init__(
            serialized_fields={"db_cluster_id": db_cluster_id},
            waiter_name="cluster_stopped",
            waiter_args={"DBClusterIdentifier": db_cluster_id},
            failure_message="Failed to stop Neptune cluster",
            status_message="Status of Neptune cluster is",
            status_queries=["DBClusters[0].Status"],
            return_key="db_cluster_id",
            return_value=db_cluster_id,
            waiter_delay=waiter_delay,
            waiter_max_attempts=waiter_max_attempts,
            aws_conn_id=aws_conn_id,
            # BUGFIX: region_name was accepted but never forwarded, so
            # self.region_name (used by hook() below) was silently dropped.
            region_name=region_name,
            **kwargs,
        )

    def hook(self) -> AwsGenericHook:
        """Build the NeptuneHook used by the base trigger's async polling loop."""
        return NeptuneHook(
            aws_conn_id=self.aws_conn_id,
            region_name=self.region_name,
            verify=self.verify,
            config=self.botocore_config,
        )
diff --git a/airflow/providers/amazon/aws/waiters/neptune.json
b/airflow/providers/amazon/aws/waiters/neptune.json
new file mode 100644
index 0000000000..d71ee9a75b
--- /dev/null
+++ b/airflow/providers/amazon/aws/waiters/neptune.json
@@ -0,0 +1,85 @@
+{
+ "version": 2,
+ "waiters": {
+ "cluster_available": {
+ "operation": "DescribeDBClusters",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "available",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "deleting",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "inaccessible-encryption-credentials",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected":
"inaccessible-encryption-credentials-recoverable",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "migration-failed",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "stopped",
+ "state": "retry"
+ }
+ ]
+ },
+ "cluster_stopped": {
+ "operation": "DescribeDBClusters",
+ "delay": 30,
+ "maxAttempts": 60,
+ "acceptors": [
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "stopped",
+ "state": "success"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "deleting",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "inaccessible-encryption-credentials",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected":
"inaccessible-encryption-credentials-recoverable",
+ "state": "failure"
+ },
+ {
+ "matcher": "path",
+ "argument": "DBClusters[0].Status",
+ "expected": "migration-failed",
+ "state": "failure"
+ }
+ ]
+ }
+ }
+}
diff --git a/airflow/providers/amazon/provider.yaml
b/airflow/providers/amazon/provider.yaml
index 5bf6a45100..65603f1098 100644
--- a/airflow/providers/amazon/provider.yaml
+++ b/airflow/providers/amazon/provider.yaml
@@ -337,6 +337,12 @@ integrations:
external-doc-url: https://aws.amazon.com/verified-permissions/
logo: /integration-logos/aws/Amazon-Verified-Permissions.png
tags: [aws]
+ - integration-name: Amazon Neptune
+ external-doc-url: https://aws.amazon.com/neptune/
+ logo: /integration-logos/aws/Amazon-Neptune_64.png
+ how-to-guide:
+ - /docs/apache-airflow-providers-amazon/operators/neptune.rst
+ tags: [aws]
operators:
- integration-name: Amazon Athena
@@ -416,6 +422,9 @@ operators:
- integration-name: AWS Glue DataBrew
python-modules:
- airflow.providers.amazon.aws.operators.glue_databrew
+ - integration-name: Amazon Neptune
+ python-modules:
+ - airflow.providers.amazon.aws.operators.neptune
sensors:
- integration-name: Amazon Athena
@@ -602,6 +611,9 @@ hooks:
- integration-name: Amazon Verified Permissions
python-modules:
- airflow.providers.amazon.aws.hooks.verified_permissions
+ - integration-name: Amazon Neptune
+ python-modules:
+ - airflow.providers.amazon.aws.hooks.neptune
triggers:
- integration-name: Amazon Web Services
@@ -654,6 +666,9 @@ triggers:
- integration-name: AWS Glue DataBrew
python-modules:
- airflow.providers.amazon.aws.triggers.glue_databrew
+ - integration-name: Amazon Neptune
+ python-modules:
+ - airflow.providers.amazon.aws.triggers.neptune
transfers:
- source-integration-name: Amazon DynamoDB
diff --git a/docs/apache-airflow-providers-amazon/operators/neptune.rst
b/docs/apache-airflow-providers-amazon/operators/neptune.rst
new file mode 100644
index 0000000000..98c0d7dd57
--- /dev/null
+++ b/docs/apache-airflow-providers-amazon/operators/neptune.rst
@@ -0,0 +1,77 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+==============
+Amazon Neptune
+==============
+
+`Amazon Neptune Database <https://aws.amazon.com/neptune/>`__ is a serverless
graph database designed
+for superior scalability and availability. Neptune Database provides built-in
security,
+continuous backups, and integrations with other AWS services.
+
+Prerequisite Tasks
+------------------
+
+.. include:: ../_partials/prerequisite_tasks.rst
+
+Generic Parameters
+------------------
+
+.. include:: ../_partials/generic_parameters.rst
+
+Operators
+---------
+
+.. _howto/operator:NeptuneStartDbClusterOperator:
+
+Start a Neptune database cluster
+================================
+
+To start an existing Neptune database cluster, you can use
+:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneStartDbClusterOperator`.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as
a parameter. This requires
+the aiobotocore module to be installed.
+
+.. note::
+ This operator only starts an existing Neptune database cluster, it does
not create a cluster.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_neptune.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_start_neptune_cluster]
+ :end-before: [END howto_operator_start_neptune_cluster]
+
+.. _howto/operator:StopNeptuneDbClusterOperator:
+
+Stop a Neptune database cluster
+===============================
+
+To stop a running Neptune database cluster, you can use
+:class:`~airflow.providers.amazon.aws.operators.neptune.NeptuneStopDbClusterOperator`.
+This operator can be run in deferrable mode by passing ``deferrable=True`` as
a parameter. This requires
+the aiobotocore module to be installed.
+
+.. exampleinclude:: /../../tests/system/providers/amazon/aws/example_neptune.py
+ :language: python
+ :dedent: 4
+ :start-after: [START howto_operator_stop_neptune_cluster]
+ :end-before: [END howto_operator_stop_neptune_cluster]
+
+Reference
+---------
+
+* `AWS boto3 library documentation for Neptune
<https://boto3.amazonaws.com/v1/documentation/api/latest/reference/services/neptune.html>`__
diff --git a/docs/integration-logos/aws/Amazon-Neptune_64.png
b/docs/integration-logos/aws/Amazon-Neptune_64.png
new file mode 100644
index 0000000000..8dd8f1e80d
Binary files /dev/null and b/docs/integration-logos/aws/Amazon-Neptune_64.png
differ
diff --git a/tests/providers/amazon/aws/hooks/test_neptune.py
b/tests/providers/amazon/aws/hooks/test_neptune.py
new file mode 100644
index 0000000000..bf8372190b
--- /dev/null
+++ b/tests/providers/amazon/aws/hooks/test_neptune.py
@@ -0,0 +1,52 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Generator
+
+import pytest
+from moto import mock_neptune
+
+from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+
+
@pytest.fixture
def neptune_hook() -> Generator[NeptuneHook, None, None]:
    """Yield a NeptuneHook whose AWS calls are intercepted by moto."""
    with mock_neptune():
        hook = NeptuneHook(aws_conn_id="aws_default")
        yield hook
+
+
@pytest.fixture
def neptune_cluster_id(neptune_hook: NeptuneHook) -> str:
    """Create a mocked Neptune cluster and return its identifier."""
    cluster = neptune_hook.conn.create_db_cluster(
        DBClusterIdentifier="test-cluster",
        Engine="neptune",
    )["DBCluster"]
    return cluster["DBClusterIdentifier"]
+
+
class TestNeptuneHook:
    """Tests for NeptuneHook backed by moto's mocked Neptune API."""

    def test_get_conn_returns_a_boto3_connection(self):
        hook = NeptuneHook(aws_conn_id="aws_default")
        conn = hook.get_conn()
        assert conn is not None

    def test_get_cluster_status(self, neptune_hook: NeptuneHook, neptune_cluster_id):
        status = neptune_hook.get_cluster_status(neptune_cluster_id)
        assert status is not None
diff --git a/tests/providers/amazon/aws/operators/test_neptune.py
b/tests/providers/amazon/aws/operators/test_neptune.py
new file mode 100644
index 0000000000..af7dc289d4
--- /dev/null
+++ b/tests/providers/amazon/aws/operators/test_neptune.py
@@ -0,0 +1,152 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from typing import Generator
+from unittest import mock
+
+import pytest
+from moto import mock_neptune
+
+from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+from airflow.providers.amazon.aws.operators.neptune import (
+ NeptuneStartDbClusterOperator,
+ NeptuneStopDbClusterOperator,
+)
+
CLUSTER_ID = "test_cluster"

EXPECTED_RESPONSE = {"db_cluster_id": CLUSTER_ID}


@pytest.fixture
def hook() -> Generator[NeptuneHook, None, None]:
    """Yield a NeptuneHook running against moto's mocked Neptune API."""
    with mock_neptune():
        yield NeptuneHook(aws_conn_id="aws_default")


@pytest.fixture
def _create_cluster(hook: NeptuneHook):
    """Create the test cluster and sanity-check that moto is intercepting AWS calls."""
    hook.conn.create_db_cluster(DBClusterIdentifier=CLUSTER_ID, Engine="neptune")
    clusters = hook.conn.describe_db_clusters()["DBClusters"]
    if not clusters:
        raise ValueError("AWS not properly mocked")
+
+
class TestNeptuneStartClusterOperator:
    """Unit tests for NeptuneStartDbClusterOperator with the hook fully mocked."""

    @staticmethod
    def _build_operator(wait_for_completion: bool) -> NeptuneStartDbClusterOperator:
        """Construct a non-deferrable start operator for the shared test cluster."""
        return NeptuneStartDbClusterOperator(
            task_id="task_test",
            db_cluster_id=CLUSTER_ID,
            deferrable=False,
            wait_for_completion=wait_for_completion,
            aws_conn_id="aws_default",
        )

    @mock.patch.object(NeptuneHook, "conn")
    @mock.patch.object(NeptuneHook, "get_waiter")
    def test_start_cluster_wait_for_completion(self, mock_hook_get_waiter, mock_conn):
        operator = self._build_operator(wait_for_completion=True)

        result = operator.execute(None)

        mock_hook_get_waiter.assert_called_once_with("cluster_available")
        assert result == EXPECTED_RESPONSE

    @mock.patch.object(NeptuneHook, "conn")
    @mock.patch.object(NeptuneHook, "get_waiter")
    def test_start_cluster_no_wait(self, mock_hook_get_waiter, mock_conn):
        operator = self._build_operator(wait_for_completion=False)

        result = operator.execute(None)

        mock_hook_get_waiter.assert_not_called()
        assert result == EXPECTED_RESPONSE

    @mock.patch.object(NeptuneHook, "conn")
    @mock.patch.object(NeptuneHook, "get_cluster_status")
    @mock.patch.object(NeptuneHook, "get_waiter")
    def test_start_cluster_cluster_available(self, mock_waiter, mock_get_cluster_status, mock_conn):
        # An already-available cluster must short-circuit: no API call, no waiter.
        mock_get_cluster_status.return_value = "available"
        operator = self._build_operator(wait_for_completion=True)

        result = operator.execute(None)

        mock_conn.start_db_cluster.assert_not_called()
        mock_waiter.assert_not_called()
        assert result == {"db_cluster_id": CLUSTER_ID}
+
+
class TestNeptuneStopClusterOperator:
    """Unit tests for NeptuneStopDbClusterOperator with the hook fully mocked."""

    @staticmethod
    def _build_operator(wait_for_completion: bool) -> NeptuneStopDbClusterOperator:
        """Construct a non-deferrable stop operator for the shared test cluster."""
        return NeptuneStopDbClusterOperator(
            task_id="task_test",
            db_cluster_id=CLUSTER_ID,
            deferrable=False,
            wait_for_completion=wait_for_completion,
            aws_conn_id="aws_default",
        )

    @mock.patch.object(NeptuneHook, "conn")
    @mock.patch.object(NeptuneHook, "get_waiter")
    def test_stop_cluster_wait_for_completion(self, mock_hook_get_waiter, mock_conn):
        operator = self._build_operator(wait_for_completion=True)

        result = operator.execute(None)

        mock_hook_get_waiter.assert_called_once_with("cluster_stopped")
        assert result == EXPECTED_RESPONSE

    @mock.patch.object(NeptuneHook, "conn")
    @mock.patch.object(NeptuneHook, "get_waiter")
    def test_stop_cluster_no_wait(self, mock_hook_get_waiter, mock_conn):
        operator = self._build_operator(wait_for_completion=False)

        result = operator.execute(None)

        mock_hook_get_waiter.assert_not_called()
        assert result == EXPECTED_RESPONSE

    @mock.patch.object(NeptuneHook, "conn")
    @mock.patch.object(NeptuneHook, "get_cluster_status")
    @mock.patch.object(NeptuneHook, "get_waiter")
    def test_stop_cluster_cluster_stopped(self, mock_waiter, mock_get_cluster_status, mock_conn):
        # An already-stopped cluster must short-circuit: no API call, no waiter.
        mock_get_cluster_status.return_value = "stopped"
        operator = self._build_operator(wait_for_completion=True)

        result = operator.execute(None)

        mock_conn.stop_db_cluster.assert_not_called()
        mock_waiter.assert_not_called()
        assert result == {"db_cluster_id": CLUSTER_ID}
diff --git a/tests/providers/amazon/aws/triggers/test_neptune.py
b/tests/providers/amazon/aws/triggers/test_neptune.py
new file mode 100644
index 0000000000..3664e1dedd
--- /dev/null
+++ b/tests/providers/amazon/aws/triggers/test_neptune.py
@@ -0,0 +1,82 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest import mock
+from unittest.mock import AsyncMock
+
+import pytest
+
+from airflow.providers.amazon.aws.triggers.neptune import (
+ NeptuneClusterAvailableTrigger,
+ NeptuneClusterStoppedTrigger,
+)
+from airflow.triggers.base import TriggerEvent
+
+CLUSTER_ID = "test-cluster"
+
+
class TestNeptuneClusterAvailableTrigger:
    def test_serialization(self):
        """
        Asserts that the NeptuneClusterAvailableTrigger correctly serializes its arguments
        and classpath.
        """
        trigger = NeptuneClusterAvailableTrigger(db_cluster_id=CLUSTER_ID)
        classpath, kwargs = trigger.serialize()
        assert classpath == "airflow.providers.amazon.aws.triggers.neptune.NeptuneClusterAvailableTrigger"
        assert "db_cluster_id" in kwargs
        assert kwargs["db_cluster_id"] == CLUSTER_ID

    @pytest.mark.asyncio
    @mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_waiter")
    @mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.async_conn")
    async def test_run_success(self, mock_async_conn, mock_get_waiter):
        """The trigger must yield a success TriggerEvent carrying the cluster id
        after its async waiter completes exactly once."""
        mock_async_conn.__aenter__.return_value = "available"
        # AsyncMock so the trigger can ``await`` the patched waiter's wait().
        mock_get_waiter().wait = AsyncMock()
        trigger = NeptuneClusterAvailableTrigger(db_cluster_id=CLUSTER_ID)
        generator = trigger.run()
        resp = await generator.asend(None)

        assert resp == TriggerEvent({"status": "success", "db_cluster_id": CLUSTER_ID})
        assert mock_get_waiter().wait.call_count == 1
+
+
class TestNeptuneClusterStoppedTrigger:
    def test_serialization(self):
        """
        Asserts that the NeptuneClusterStoppedTrigger correctly serializes its arguments
        and classpath.
        """
        trigger = NeptuneClusterStoppedTrigger(db_cluster_id=CLUSTER_ID)
        classpath, kwargs = trigger.serialize()
        assert classpath == "airflow.providers.amazon.aws.triggers.neptune.NeptuneClusterStoppedTrigger"
        assert "db_cluster_id" in kwargs
        assert kwargs["db_cluster_id"] == CLUSTER_ID

    @pytest.mark.asyncio
    @mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.get_waiter")
    @mock.patch("airflow.providers.amazon.aws.hooks.neptune.NeptuneHook.async_conn")
    async def test_run_success(self, mock_async_conn, mock_get_waiter):
        """The trigger must yield a success TriggerEvent carrying the cluster id
        after its async waiter completes exactly once."""
        mock_async_conn.__aenter__.return_value = "stopped"
        # AsyncMock so the trigger can ``await`` the patched waiter's wait().
        mock_get_waiter().wait = AsyncMock()
        trigger = NeptuneClusterStoppedTrigger(db_cluster_id=CLUSTER_ID)
        generator = trigger.run()
        resp = await generator.asend(None)

        assert resp == TriggerEvent({"status": "success", "db_cluster_id": CLUSTER_ID})
        assert mock_get_waiter().wait.call_count == 1
diff --git a/tests/providers/amazon/aws/waiters/test_neptune.py
b/tests/providers/amazon/aws/waiters/test_neptune.py
new file mode 100644
index 0000000000..118690e6c9
--- /dev/null
+++ b/tests/providers/amazon/aws/waiters/test_neptune.py
@@ -0,0 +1,89 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+from unittest import mock
+
+import boto3
+import botocore
+import pytest
+
+from airflow.providers.amazon.aws.hooks.neptune import NeptuneHook
+
+
class TestCustomNeptuneWaiters:
    """Test waiters from ``amazon/aws/waiters/neptune.json``.

    The hook's boto3 client is patched onto the class so the waiters poll a
    mocked ``describe_db_clusters`` instead of AWS.
    """

    @pytest.fixture(autouse=True)
    def setup_test_cases(self, monkeypatch):
        # Replace the hook's cached connection with a local client whose
        # describe call the tests mock; no real AWS traffic occurs.
        self.client = boto3.client("neptune", region_name="eu-west-3")
        monkeypatch.setattr(NeptuneHook, "conn", self.client)

    def test_service_waiters(self):
        """Both custom waiters must be discoverable through the hook."""
        hook_waiters = NeptuneHook(aws_conn_id=None).list_waiters()
        assert "cluster_available" in hook_waiters
        assert "cluster_stopped" in hook_waiters

    @pytest.fixture()
    def mock_describe_clusters(self):
        with mock.patch.object(self.client, "describe_db_clusters") as m:
            yield m

    @staticmethod
    def get_status_response(status):
        """Build a minimal ``DescribeDBClusters`` response with the given status."""
        return {"DBClusters": [{"Status": status}]}

    def test_cluster_available(self, mock_describe_clusters):
        """Waiter returns immediately when the cluster is already available."""
        mock_describe_clusters.return_value = self.get_status_response("available")
        waiter = NeptuneHook(aws_conn_id=None).get_waiter("cluster_available")
        waiter.wait(DBClusterIdentifier="test_cluster")

    def test_cluster_available_failed(self, mock_describe_clusters):
        """A terminal failure state must abort the waiter with WaiterError."""
        mock_describe_clusters.return_value = self.get_status_response("migration-failed")
        waiter = NeptuneHook(aws_conn_id=None).get_waiter("cluster_available")
        # Keep only the raising call inside the raises block.
        with pytest.raises(botocore.exceptions.WaiterError):
            waiter.wait(DBClusterIdentifier="test_cluster")

    def test_starting_up(self, mock_describe_clusters):
        """Waiter polls through stopped -> starting -> available."""
        mock_describe_clusters.side_effect = [
            self.get_status_response("stopped"),
            self.get_status_response("starting"),
            self.get_status_response("available"),
        ]
        waiter = NeptuneHook(aws_conn_id=None).get_waiter("cluster_available")
        # Fixed: use DBClusterIdentifier like the other tests; the mocked
        # describe call accepted any kwarg, hiding the inconsistency.
        waiter.wait(DBClusterIdentifier="test_cluster", WaiterConfig={"Delay": 0.2, "MaxAttempts": 4})

    def test_cluster_stopped(self, mock_describe_clusters):
        """Waiter returns immediately when the cluster is already stopped."""
        mock_describe_clusters.return_value = self.get_status_response("stopped")
        waiter = NeptuneHook(aws_conn_id=None).get_waiter("cluster_stopped")
        waiter.wait(DBClusterIdentifier="test_cluster")

    def test_cluster_stopped_failed(self, mock_describe_clusters):
        """A terminal failure state must abort the waiter with WaiterError."""
        mock_describe_clusters.return_value = self.get_status_response("migration-failed")
        waiter = NeptuneHook(aws_conn_id=None).get_waiter("cluster_stopped")
        with pytest.raises(botocore.exceptions.WaiterError):
            waiter.wait(DBClusterIdentifier="test_cluster")

    def test_stopping(self, mock_describe_clusters):
        """Waiter polls through available -> stopping -> stopped."""
        mock_describe_clusters.side_effect = [
            self.get_status_response("available"),
            self.get_status_response("stopping"),
            self.get_status_response("stopped"),
        ]
        waiter = NeptuneHook(aws_conn_id=None).get_waiter("cluster_stopped")
        # Fixed: DBClusterIdentifier for consistency (was cluster_identifier).
        waiter.wait(DBClusterIdentifier="test_cluster", WaiterConfig={"Delay": 0.2, "MaxAttempts": 4})
diff --git a/tests/system/providers/amazon/aws/example_neptune.py
b/tests/system/providers/amazon/aws/example_neptune.py
new file mode 100644
index 0000000000..fc9d4226b5
--- /dev/null
+++ b/tests/system/providers/amazon/aws/example_neptune.py
@@ -0,0 +1,68 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import pendulum
+
+from airflow.models.baseoperator import chain
+from airflow.models.dag import DAG
+from airflow.providers.amazon.aws.operators.neptune import (
+ NeptuneStartDbClusterOperator,
+ NeptuneStopDbClusterOperator,
+)
+from tests.system.providers.amazon.aws.utils import SystemTestContextBuilder
+
DAG_ID = "example_neptune"
# This test requires an existing Neptune cluster.
CLUSTER_ID = "CLUSTER_ID"

# Pulls ENV_ID and the CLUSTER_ID variable from the system-test context.
sys_test_context_task = SystemTestContextBuilder().add_variable(CLUSTER_ID).build()

with DAG(DAG_ID, schedule="@once", start_date=pendulum.datetime(2024, 1, 1, tz="UTC"), catchup=False) as dag:
    test_context = sys_test_context_task()
    env_id = test_context["ENV_ID"]
    cluster_id = test_context["CLUSTER_ID"]

    # [START howto_operator_start_neptune_cluster]
    start_cluster = NeptuneStartDbClusterOperator(
        task_id="start_task", db_cluster_id=cluster_id, deferrable=True
    )
    # [END howto_operator_start_neptune_cluster]

    # [START howto_operator_stop_neptune_cluster]
    stop_cluster = NeptuneStopDbClusterOperator(
        task_id="stop_task", db_cluster_id=cluster_id, deferrable=True
    )
    # [END howto_operator_stop_neptune_cluster]

    # Start then stop the existing cluster, sequentially.
    chain(
        # TEST SETUP
        test_context,
        # TEST BODY
        start_cluster,
        stop_cluster,
    )
    from tests.system.utils.watcher import watcher

    # This test needs watcher in order to properly mark success/failure
    # when "tearDown" task with trigger rule is part of the DAG
    list(dag.tasks) >> watcher()

from tests.system.utils import get_test_run  # noqa: E402

# Needed to run the example DAG with pytest (see: tests/system/README.md#run_via_pytest)
test_run = get_test_run(dag)