This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 998fcd6cfbc Improve SFTP hook's directory transfer to use a single connection for multiple files (#46582)
998fcd6cfbc is described below

commit 998fcd6cfbc35b671a07b92d6f6fc532a00bd8dd
Author: Dawnpool <[email protected]>
AuthorDate: Sat Mar 1 21:08:38 2025 +0900

    Improve SFTP hook's directory transfer to use a single connection for multiple files (#46582)
    
    * Improve SFTP directory transfer to use a single connection in multiple files
    
    * Add with_conn wrapper
    
    * Fix delete_directory
    
    * Delete wrapper and update get_conn
    
    * Add test code
---
 .../sftp/src/airflow/providers/sftp/hooks/sftp.py  | 69 +++++++++++++++-------
 providers/sftp/tests/unit/sftp/hooks/test_sftp.py  | 10 +++-
 2 files changed, 56 insertions(+), 23 deletions(-)

diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
index 0b38ffaea9a..dca03b634f2 100644
--- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
+++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
@@ -24,7 +24,7 @@ import os
 import stat
 import warnings
 from collections.abc import Generator, Sequence
-from contextlib import closing, contextmanager
+from contextlib import contextmanager
 from fnmatch import fnmatch
 from io import BytesIO
 from pathlib import Path
@@ -38,6 +38,7 @@ from airflow.hooks.base import BaseHook
 from airflow.providers.ssh.hooks.ssh import SSHHook
 
 if TYPE_CHECKING:
+    from paramiko import SSHClient
     from paramiko.sftp_attr import SFTPAttributes
     from paramiko.sftp_client import SFTPClient
 
@@ -110,6 +111,10 @@ class SFTPHook(SSHHook):
         kwargs["host_proxy_cmd"] = host_proxy_cmd
         self.ssh_conn_id = ssh_conn_id
 
+        self._ssh_conn: SSHClient | None = None
+        self._sftp_conn: SFTPClient | None = None
+        self._conn_count = 0
+
         super().__init__(*args, **kwargs)
 
     def get_conn(self) -> SFTPClient:  # type: ignore[override]
@@ -127,9 +132,25 @@ class SFTPHook(SSHHook):
     @contextmanager
     def get_managed_conn(self) -> Generator[SFTPClient, None, None]:
         """Context manager that closes the connection after use."""
-        with closing(super().get_conn()) as conn:
-            with closing(conn.open_sftp()) as sftp:
-                yield sftp
+        if self._sftp_conn is None:
+            ssh_conn: SSHClient = super().get_conn()
+            self._ssh_conn = ssh_conn
+            self._sftp_conn = ssh_conn.open_sftp()
+        self._conn_count += 1
+
+        try:
+            yield self._sftp_conn
+        finally:
+            self._conn_count -= 1
+            if self._conn_count == 0 and self._ssh_conn is not None and 
self._sftp_conn is not None:
+                self._sftp_conn.close()
+                self._sftp_conn = None
+                self._ssh_conn.close()
+                self._ssh_conn = None
+
+    def get_conn_count(self) -> int:
+        """Get the number of open connections."""
+        return self._conn_count
 
     def describe_directory(self, path: str) -> dict[str, dict[str, str | int | 
None]]:
         """
@@ -309,13 +330,14 @@ class SFTPHook(SSHHook):
         if Path(local_full_path).exists():
             raise AirflowException(f"{local_full_path} already exists")
         Path(local_full_path).mkdir(parents=True)
-        files, dirs, _ = self.get_tree_map(remote_full_path)
-        for dir_path in dirs:
-            new_local_path = os.path.join(local_full_path, 
os.path.relpath(dir_path, remote_full_path))
-            Path(new_local_path).mkdir(parents=True, exist_ok=True)
-        for file_path in files:
-            new_local_path = os.path.join(local_full_path, 
os.path.relpath(file_path, remote_full_path))
-            self.retrieve_file(file_path, new_local_path, prefetch)
+        with self.get_conn():
+            files, dirs, _ = self.get_tree_map(remote_full_path)
+            for dir_path in dirs:
+                new_local_path = os.path.join(local_full_path, 
os.path.relpath(dir_path, remote_full_path))
+                Path(new_local_path).mkdir(parents=True, exist_ok=True)
+            for file_path in files:
+                new_local_path = os.path.join(local_full_path, 
os.path.relpath(file_path, remote_full_path))
+                self.retrieve_file(file_path, new_local_path, prefetch)
 
     def store_directory(self, remote_full_path: str, local_full_path: str, 
confirm: bool = True) -> None:
         """
@@ -329,16 +351,21 @@ class SFTPHook(SSHHook):
         """
         if self.path_exists(remote_full_path):
             raise AirflowException(f"{remote_full_path} already exists")
-        self.create_directory(remote_full_path)
-        for root, dirs, files in os.walk(local_full_path):
-            for dir_name in dirs:
-                dir_path = os.path.join(root, dir_name)
-                new_remote_path = os.path.join(remote_full_path, 
os.path.relpath(dir_path, local_full_path))
-                self.create_directory(new_remote_path)
-            for file_name in files:
-                file_path = os.path.join(root, file_name)
-                new_remote_path = os.path.join(remote_full_path, 
os.path.relpath(file_path, local_full_path))
-                self.store_file(new_remote_path, file_path, confirm)
+        with self.get_conn():
+            self.create_directory(remote_full_path)
+            for root, dirs, files in os.walk(local_full_path):
+                for dir_name in dirs:
+                    dir_path = os.path.join(root, dir_name)
+                    new_remote_path = os.path.join(
+                        remote_full_path, os.path.relpath(dir_path, 
local_full_path)
+                    )
+                    self.create_directory(new_remote_path)
+                for file_name in files:
+                    file_path = os.path.join(root, file_name)
+                    new_remote_path = os.path.join(
+                        remote_full_path, os.path.relpath(file_path, 
local_full_path)
+                    )
+                    self.store_file(new_remote_path, file_path, confirm)
 
     def get_mod_time(self, path: str) -> str:
         """
diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
index b1c0d13c369..38239546da8 100644
--- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
+++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
@@ -114,8 +114,14 @@ class TestSFTPHook:
         assert self.hook.conn is None
 
     def test_get_managed_conn(self):
-        with self.hook.get_managed_conn() as conn:
-            assert isinstance(conn, paramiko.SFTPClient)
+        with self.hook.get_managed_conn() as conn1:
+            assert isinstance(conn1, paramiko.SFTPClient)
+            with self.hook.get_managed_conn() as conn2:
+                assert conn1 == conn2
+                assert self.hook.get_conn_count() == 2
+            assert self.hook.get_conn_count() == 1
+        assert self.hook.get_conn_count() == 0
+        assert self.hook.conn is None
 
     @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn")
     def test_get_close_conn(self, mock_get_conn):

Reply via email to