This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2f26da7023 feat: load host keys to save new host key (#25979)
2f26da7023 is described below

commit 2f26da70230d7d1cf7dfb3a20d38e9a5844862a7
Author: doiken <[email protected]>
AuthorDate: Sat Aug 27 11:02:52 2022 +0900

    feat: load host keys to save new host key (#25979)
    
    Co-authored-by: doiken <[email protected]>
---
 airflow/providers/ssh/hooks/ssh.py    | 11 ++++---
 tests/providers/ssh/hooks/test_ssh.py | 55 +++++++++++++++++++++++++++++++++++
 2 files changed, 62 insertions(+), 4 deletions(-)

diff --git a/airflow/providers/ssh/hooks/ssh.py b/airflow/providers/ssh/hooks/ssh.py
index 412842daa8..17545b1755 100644
--- a/airflow/providers/ssh/hooks/ssh.py
+++ b/airflow/providers/ssh/hooks/ssh.py
@@ -280,11 +280,18 @@ class SSHHook(BaseHook):
                 "Remote Identification Change is not verified. "
                 "This won't protect against Man-In-The-Middle attacks"
             )
+            # to avoid BadHostKeyException, skip loading host keys
+            client.set_missing_host_key_policy(paramiko.MissingHostKeyPolicy)
         else:
             client.load_system_host_keys()
 
         if self.no_host_key_check:
             self.log.warning("No Host Key Verification. This won't protect against Man-In-The-Middle attacks")
+            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
+            # to avoid BadHostKeyException, skip loading and saving host keys
+            known_hosts = os.path.expanduser("~/.ssh/known_hosts")
+            if not self.allow_host_key_change and os.path.isfile(known_hosts):
+                client.load_host_keys(known_hosts)
         else:
             if self.host_key is not None:
                 client_host_keys = client.get_host_keys()
@@ -297,10 +304,6 @@ class SSHHook(BaseHook):
             else:
                 pass  # will fallback to system host keys if none explicitly specified in conn extra
 
-        if self.no_host_key_check or self.allow_host_key_change:
-            # Default is RejectPolicy
-            client.set_missing_host_key_policy(paramiko.AutoAddPolicy())
-
         connect_kwargs: Dict[str, Any] = dict(
             hostname=self.remote_host,
             username=self.username,
diff --git a/tests/providers/ssh/hooks/test_ssh.py b/tests/providers/ssh/hooks/test_ssh.py
index e11482336b..7afe855834 100644
--- a/tests/providers/ssh/hooks/test_ssh.py
+++ b/tests/providers/ssh/hooks/test_ssh.py
@@ -102,6 +102,12 @@ class TestSSHHook(unittest.TestCase):
     )
     CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS = 'ssh_with_extra_disabled_algorithms'
     CONN_SSH_WITH_EXTRA_CIPHERS = 'ssh_with_extra_ciphers'
+    CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_TRUE = (
+        'ssh_with_no_host_key_check_true_and_allow_host_key_changes_true'
+    )
+    CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_FALSE = (
+        'ssh_with_no_host_key_check_true_and_allow_host_key_changes_false'
+    )
 
     @classmethod
     def tearDownClass(cls) -> None:
@@ -123,6 +129,8 @@ class TestSSHHook(unittest.TestCase):
                 cls.CONN_SSH_WITH_NO_HOST_KEY_AND_NO_HOST_KEY_CHECK_TRUE,
                 cls.CONN_SSH_WITH_EXTRA_DISABLED_ALGORITHMS,
                 cls.CONN_SSH_WITH_EXTRA_CIPHERS,
+                cls.CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_TRUE,
+                cls.CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_FALSE,
             ]
             connections = session.query(Connection).filter(Connection.conn_id.in_(conns_to_reset))
             connections.delete(synchronize_session=False)
@@ -287,6 +295,22 @@ class TestSSHHook(unittest.TestCase):
                 extra=json.dumps({"ciphers": TEST_CIPHERS}),
             )
         )
+        db.merge_conn(
+            Connection(
+                conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_TRUE,
+                host='remote_host',
+                conn_type='ssh',
+                extra=json.dumps({"no_host_key_check": True, "allow_host_key_change": True}),
+            )
+        )
+        db.merge_conn(
+            Connection(
+                conn_id=cls.CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_FALSE,
+                host='remote_host',
+                conn_type='ssh',
+                extra=json.dumps({"no_host_key_check": True, "allow_host_key_change": False}),
+            )
+        )
 
     @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
     def test_ssh_connection_with_password(self, ssh_mock):
@@ -881,3 +905,34 @@ class TestSSHHook(unittest.TestCase):
                 30,
             )
             assert ret == (0, b'airflow\n', b'')
+
+    @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+    def test_ssh_connection_with_no_host_key_check_true_and_allow_host_key_changes_true(self, ssh_mock):
+        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_TRUE)
+        with hook.get_conn():
+            assert ssh_mock.return_value.set_missing_host_key_policy.called is True
+            assert isinstance(
+                ssh_mock.return_value.set_missing_host_key_policy.call_args[0][0], paramiko.AutoAddPolicy
+            )
+            assert ssh_mock.return_value.load_host_keys.called is False
+
+    @mock.patch('airflow.providers.ssh.hooks.ssh.paramiko.SSHClient')
+    def test_ssh_connection_with_no_host_key_check_true_and_allow_host_key_changes_false(self, ssh_mock):
+        hook = SSHHook(ssh_conn_id=self.CONN_SSH_WITH_NO_HOST_KEY_CHECK_TRUE_AND_ALLOW_HOST_KEY_CHANGES_FALSE)
+
+        with mock.patch('os.path.isfile', return_value=True):
+            with hook.get_conn():
+                assert ssh_mock.return_value.set_missing_host_key_policy.called is True
+                assert isinstance(
+                    ssh_mock.return_value.set_missing_host_key_policy.call_args[0][0], paramiko.AutoAddPolicy
+                )
+                assert ssh_mock.return_value.load_host_keys.called is True
+
+        ssh_mock.reset_mock()
+        with mock.patch('os.path.isfile', return_value=False):
+            with hook.get_conn():
+                assert ssh_mock.return_value.set_missing_host_key_policy.called is True
+                assert isinstance(
+                    ssh_mock.return_value.set_missing_host_key_policy.call_args[0][0], paramiko.AutoAddPolicy
+                )
+                assert ssh_mock.return_value.load_host_keys.called is False

Reply via email to