This is an automated email from the ASF dual-hosted git repository.

eladkal pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 78b179f0c1 Switch to `mongo_conn_id` argument into the MongoHook constructor (#36896)
78b179f0c1 is described below

commit 78b179f0c1cba0dfb38d3db7df037b85e208d19c
Author: Andrey Anshin <[email protected]>
AuthorDate: Sat Jan 20 01:45:26 2024 +0400

    Switch to `mongo_conn_id` argument into the MongoHook constructor (#36896)
---
 airflow/providers/mongo/hooks/mongo.py             | 17 ++++--
 airflow/providers/mongo/sensors/mongo.py           |  2 +-
 .../providers/mongo/sensors/test_mongo.py          | 24 ++++++---
 tests/providers/mongo/hooks/test_mongo.py          | 60 +++++++++++++++++-----
 4 files changed, 78 insertions(+), 25 deletions(-)

diff --git a/airflow/providers/mongo/hooks/mongo.py b/airflow/providers/mongo/hooks/mongo.py
index d7fb7bb53b..9eacfb6817 100644
--- a/airflow/providers/mongo/hooks/mongo.py
+++ b/airflow/providers/mongo/hooks/mongo.py
@@ -18,6 +18,7 @@
 """Hook for Mongo DB."""
 from __future__ import annotations
 
+import warnings
 from ssl import CERT_NONE
 from typing import TYPE_CHECKING, Any, overload
 from urllib.parse import quote_plus, urlunsplit
@@ -25,6 +26,7 @@ from urllib.parse import quote_plus, urlunsplit
 import pymongo
 from pymongo import MongoClient, ReplaceOne
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.hooks.base import BaseHook
 
 if TYPE_CHECKING:
@@ -57,10 +59,19 @@ class MongoHook(BaseHook):
     conn_type = "mongo"
     hook_name = "MongoDB"
 
-    def __init__(self, conn_id: str = default_conn_name, *args, **kwargs) -> None:
+    def __init__(self, mongo_conn_id: str = default_conn_name, *args, **kwargs) -> None:
         super().__init__(logger_name=kwargs.pop("logger_name", None))
-        self.mongo_conn_id = conn_id
-        self.connection = self.get_connection(conn_id)
+        if conn_id := kwargs.pop("conn_id", None):
+            warnings.warn(
+                "Parameter `conn_id` is deprecated and will be removed in a future releases. "
+                "Please use `mongo_conn_id` instead.",
+                AirflowProviderDeprecationWarning,
+                stacklevel=2,
+            )
+            mongo_conn_id = conn_id
+
+        self.mongo_conn_id = mongo_conn_id
+        self.connection = self.get_connection(self.mongo_conn_id)
         self.extras = self.connection.extra_dejson.copy()
         self.client: MongoClient | None = None
         self.uri = self._create_uri()
diff --git a/airflow/providers/mongo/sensors/mongo.py b/airflow/providers/mongo/sensors/mongo.py
index cd79591986..724ab5b4c6 100644
--- a/airflow/providers/mongo/sensors/mongo.py
+++ b/airflow/providers/mongo/sensors/mongo.py
@@ -62,5 +62,5 @@ class MongoSensor(BaseSensorOperator):
         self.log.info(
             "Sensor check existence of the document that matches the following query: %s", self.query
         )
-        hook = MongoHook(self.mongo_conn_id)
+        hook = MongoHook(mongo_conn_id=self.mongo_conn_id)
         return hook.find(self.collection, self.query, mongo_db=self.mongo_db, find_one=True) is not None
diff --git a/tests/integration/providers/mongo/sensors/test_mongo.py b/tests/integration/providers/mongo/sensors/test_mongo.py
index 98eaec5273..523e6e1c04 100644
--- a/tests/integration/providers/mongo/sensors/test_mongo.py
+++ b/tests/integration/providers/mongo/sensors/test_mongo.py
@@ -23,22 +23,32 @@ from airflow.models import Connection
 from airflow.models.dag import DAG
 from airflow.providers.mongo.hooks.mongo import MongoHook
 from airflow.providers.mongo.sensors.mongo import MongoSensor
-from airflow.utils import db, timezone
+from airflow.utils import timezone
 
 DEFAULT_DATE = timezone.datetime(2017, 1, 1)
 
 
+@pytest.fixture(scope="module", autouse=True)
+def mongo_connections():
+    """Create MongoDB connections which use for testing purpose."""
+    connections = [
+        Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017),
+        Connection(conn_id="mongo_test", conn_type="mongo", host="mongo", port=27017, schema="test"),
+    ]
+
+    with pytest.MonkeyPatch.context() as mp:
+        for conn in connections:
+            mp.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.as_json())
+        yield
+
+
 @pytest.mark.integration("mongo")
 class TestMongoSensor:
     def setup_method(self):
-        db.merge_conn(
-            Connection(conn_id="mongo_test", conn_type="mongo", host="mongo", port=27017, schema="test")
-        )
-
         args = {"owner": "airflow", "start_date": DEFAULT_DATE}
         self.dag = DAG("test_dag_id", default_args=args)
 
-        hook = MongoHook("mongo_test")
+        hook = MongoHook(mongo_conn_id="mongo_test")
         hook.insert_one("foo", {"bar": "baz"})
 
         self.sensor = MongoSensor(
@@ -53,7 +63,7 @@ class TestMongoSensor:
         assert self.sensor.poke(None)
 
     def test_sensor_with_db(self):
-        hook = MongoHook("mongo_test")
+        hook = MongoHook(mongo_conn_id="mongo_test")
         hook.insert_one("nontest", {"1": "2"}, mongo_db="nontest")
 
         sensor = MongoSensor(
diff --git a/tests/providers/mongo/hooks/test_mongo.py b/tests/providers/mongo/hooks/test_mongo.py
index 718a5381b7..d546c93e9d 100644
--- a/tests/providers/mongo/hooks/test_mongo.py
+++ b/tests/providers/mongo/hooks/test_mongo.py
@@ -18,14 +18,15 @@
 from __future__ import annotations
 
 import importlib
+import warnings
 from typing import TYPE_CHECKING
 
 import pymongo
 import pytest
 
+from airflow.exceptions import AirflowProviderDeprecationWarning
 from airflow.models import Connection
 from airflow.providers.mongo.hooks.mongo import MongoHook
-from airflow.utils import db
 
 pytestmark = pytest.mark.db_test
 
@@ -40,14 +41,36 @@ except ImportError:
     mongomock = None
 
 
+@pytest.fixture(scope="module", autouse=True)
+def mongo_connections():
+    """Create MongoDB connections which use for testing purpose."""
+    connections = [
+        Connection(conn_id="mongo_default", conn_type="mongo", host="mongo", port=27017),
+        Connection(
+            conn_id="mongo_default_with_srv",
+            conn_type="mongo",
+            host="mongo",
+            port=27017,
+            extra='{"srv": true}',
+        ),
+        # Mongo establishes connection during initialization, so we need to have this connection
+        Connection(conn_id="fake_connection", conn_type="mongo", host="mongo", port=27017),
+    ]
+
+    with pytest.MonkeyPatch.context() as mp:
+        for conn in connections:
+            mp.setenv(f"AIRFLOW_CONN_{conn.conn_id.upper()}", conn.as_json())
+        yield
+
+
 class MongoHookTest(MongoHook):
     """
     Extending hook so that a mockmongo collection object can be passed in
     to get_collection()
     """
 
-    def __init__(self, conn_id="mongo_default", *args, **kwargs):
-        super().__init__(conn_id=conn_id, *args, **kwargs)
+    def __init__(self, mongo_conn_id="mongo_default", *args, **kwargs):
+        super().__init__(mongo_conn_id=mongo_conn_id, *args, **kwargs)
 
     def get_collection(self, mock_collection, mongo_db=None):
         return mock_collection
@@ -56,24 +79,33 @@ class MongoHookTest(MongoHook):
 @pytest.mark.skipif(mongomock is None, reason="mongomock package not present")
 class TestMongoHook:
     def setup_method(self):
-        self.hook = MongoHookTest(conn_id="mongo_default", mongo_db="default")
+        self.hook = MongoHookTest(mongo_conn_id="mongo_default")
         self.conn = self.hook.get_conn()
-        db.merge_conn(
-            Connection(
-                conn_id="mongo_default_with_srv",
-                conn_type="mongo",
-                host="mongo",
-                port=27017,
-                extra='{"srv": true}',
+
+    def test_mongo_conn_id(self):
+        with warnings.catch_warnings():
+            warnings.simplefilter("error", category=AirflowProviderDeprecationWarning)
+            # Use default "mongo_default"
+            assert MongoHook().mongo_conn_id == "mongo_default"
+            # Positional argument
+            assert MongoHook("fake_connection").mongo_conn_id == "fake_connection"
+
+        warning_message = "Parameter `conn_id` is deprecated"
+        with pytest.warns(AirflowProviderDeprecationWarning, match=warning_message):
+            assert MongoHook(conn_id="fake_connection").mongo_conn_id == "fake_connection"
+
+        with pytest.warns(AirflowProviderDeprecationWarning, match=warning_message):
+            assert (
+                MongoHook(conn_id="fake_connection", mongo_conn_id="foo-bar").mongo_conn_id
+                == "fake_connection"
             )
-        )
 
     def test_get_conn(self):
         assert self.hook.connection.port == 27017
         assert isinstance(self.conn, pymongo.MongoClient)
 
     def test_srv(self):
-        hook = MongoHook(conn_id="mongo_default_with_srv")
+        hook = MongoHook(mongo_conn_id="mongo_default_with_srv")
         assert hook.uri.startswith("mongodb+srv://")
 
     def test_insert_one(self):
@@ -333,7 +365,7 @@ class TestMongoHook:
 
 
 def test_context_manager():
-    with MongoHook(conn_id="mongo_default", mongo_db="default") as ctx_hook:
+    with MongoHook(mongo_conn_id="mongo_default") as ctx_hook:
         ctx_hook.get_conn()
 
         assert isinstance(ctx_hook, MongoHook)

Reply via email to