This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new dce9796861 Add init_command parameter to MySqlHook (#33359)
dce9796861 is described below

commit dce9796861e0a535952f79b0e2a7d5a012fcc01b
Author: Alex Begg <[email protected]>
AuthorDate: Thu Aug 17 22:58:51 2023 -0700

    Add init_command parameter to MySqlHook (#33359)
    
    This allows the addition of `init_command` as a `MySqlHook` parameter.
    
    For example, to set the MySQL session's `time_zone` to UTC you can set the `init_command` to `SET time_zone = '+00:00';`.
---
 airflow/providers/mysql/hooks/mysql.py                     |  6 ++++++
 tests/providers/mysql/hooks/test_mysql.py                  |  9 +++++++++
 tests/providers/mysql/hooks/test_mysql_connector_python.py | 11 +++++++++++
 3 files changed, 26 insertions(+)

diff --git a/airflow/providers/mysql/hooks/mysql.py b/airflow/providers/mysql/hooks/mysql.py
index fa011ed35b..3fe86652cc 100644
--- a/airflow/providers/mysql/hooks/mysql.py
+++ b/airflow/providers/mysql/hooks/mysql.py
@@ -58,6 +58,7 @@ class MySqlHook(DbApiHook):
     :param schema: The MySQL database schema to connect to.
     :param connection: The :ref:`MySQL connection id <howto/connection:mysql>` used for MySQL credentials.
     :param local_infile: Boolean flag determining if local_infile should be used
+    :param init_command: Initial command to issue to MySQL server upon connection
     """
 
     conn_name_attr = "mysql_conn_id"
@@ -71,6 +72,7 @@ class MySqlHook(DbApiHook):
         self.schema = kwargs.pop("schema", None)
         self.connection = kwargs.pop("connection", None)
         self.local_infile = kwargs.pop("local_infile", False)
+        self.init_command = kwargs.pop("init_command", None)
 
     def set_autocommit(self, conn: MySQLConnectionTypes, autocommit: bool) -> None:
         """
@@ -144,6 +146,8 @@ class MySqlHook(DbApiHook):
             conn_config["unix_socket"] = conn.extra_dejson["unix_socket"]
         if self.local_infile:
             conn_config["local_infile"] = 1
+        if self.init_command:
+            conn_config["init_command"] = self.init_command
         return conn_config
 
     def _get_conn_config_mysql_connector_python(self, conn: Connection) -> dict:
@@ -157,6 +161,8 @@ class MySqlHook(DbApiHook):
 
         if self.local_infile:
             conn_config["allow_local_infile"] = True
+        if self.init_command:
+            conn_config["init_command"] = self.init_command
         # Ref: https://dev.mysql.com/doc/connector-python/en/connector-python-connectargs.html
         for key, value in conn.extra_dejson.items():
             if key.startswith("ssl_"):
diff --git a/tests/providers/mysql/hooks/test_mysql.py b/tests/providers/mysql/hooks/test_mysql.py
index 81773e9e3f..2b6c81df5d 100644
--- a/tests/providers/mysql/hooks/test_mysql.py
+++ b/tests/providers/mysql/hooks/test_mysql.py
@@ -181,6 +181,15 @@ class TestMySqlHookConn:
             read_default_group="enable-cleartext-plugin",
         )
 
+    @mock.patch("MySQLdb.connect")
+    def test_get_conn_init_command(self, mock_connect):
+        self.db_hook.init_command = "SET time_zone = '+00:00';"
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        args, kwargs = mock_connect.call_args
+        assert args == ()
+        assert kwargs["init_command"] == "SET time_zone = '+00:00';"
+
 
 class MockMySQLConnectorConnection:
     DEFAULT_AUTOCOMMIT = "default"
diff --git a/tests/providers/mysql/hooks/test_mysql_connector_python.py b/tests/providers/mysql/hooks/test_mysql_connector_python.py
index d4c9ffc0b3..5d2f0d1613 100644
--- a/tests/providers/mysql/hooks/test_mysql_connector_python.py
+++ b/tests/providers/mysql/hooks/test_mysql_connector_python.py
@@ -79,3 +79,14 @@ class TestMySqlHookConnMySqlConnectorPython:
         args, kwargs = mock_connect.call_args
         assert args == ()
         assert kwargs["ssl_disabled"] == 1
+
+    @mock.patch("mysql.connector.connect")
+    def test_get_conn_init_command(self, mock_connect):
+        extra_dict = self.connection.extra_dejson
+        self.connection.extra = json.dumps(extra_dict)
+        self.db_hook.init_command = "SET time_zone = '+00:00';"
+        self.db_hook.get_conn()
+        assert mock_connect.call_count == 1
+        args, kwargs = mock_connect.call_args
+        assert args == ()
+        assert kwargs["init_command"] == "SET time_zone = '+00:00';"

Reply via email to