This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 998fcd6cfbc Improve SFTP hook's directory transfer to use a single
connection for multiple files (#46582)
998fcd6cfbc is described below
commit 998fcd6cfbc35b671a07b92d6f6fc532a00bd8dd
Author: Dawnpool <[email protected]>
AuthorDate: Sat Mar 1 21:08:38 2025 +0900
Improve SFTP hook's directory transfer to use a single connection for
multiple files (#46582)
* Improve SFTP directory transfer to use a single connection for multiple files
* Add with_conn wrapper
* Fix delete_directory
* Delete wrapper and update get_conn
* Add test code
---
.../sftp/src/airflow/providers/sftp/hooks/sftp.py | 69 +++++++++++++++-------
providers/sftp/tests/unit/sftp/hooks/test_sftp.py | 10 +++-
2 files changed, 56 insertions(+), 23 deletions(-)
diff --git a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
index 0b38ffaea9a..dca03b634f2 100644
--- a/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
+++ b/providers/sftp/src/airflow/providers/sftp/hooks/sftp.py
@@ -24,7 +24,7 @@ import os
import stat
import warnings
from collections.abc import Generator, Sequence
-from contextlib import closing, contextmanager
+from contextlib import contextmanager
from fnmatch import fnmatch
from io import BytesIO
from pathlib import Path
@@ -38,6 +38,7 @@ from airflow.hooks.base import BaseHook
from airflow.providers.ssh.hooks.ssh import SSHHook
if TYPE_CHECKING:
+ from paramiko import SSHClient
from paramiko.sftp_attr import SFTPAttributes
from paramiko.sftp_client import SFTPClient
@@ -110,6 +111,10 @@ class SFTPHook(SSHHook):
kwargs["host_proxy_cmd"] = host_proxy_cmd
self.ssh_conn_id = ssh_conn_id
+ self._ssh_conn: SSHClient | None = None
+ self._sftp_conn: SFTPClient | None = None
+ self._conn_count = 0
+
super().__init__(*args, **kwargs)
def get_conn(self) -> SFTPClient: # type: ignore[override]
@@ -127,9 +132,25 @@ class SFTPHook(SSHHook):
     @contextmanager
     def get_managed_conn(self) -> Generator[SFTPClient, None, None]:
         """Context manager that closes the connection after use."""
-        with closing(super().get_conn()) as conn:
-            with closing(conn.open_sftp()) as sftp:
-                yield sftp
+        if self._sftp_conn is None:
+            ssh_conn: SSHClient = super().get_conn()
+            self._ssh_conn = ssh_conn
+            self._sftp_conn = ssh_conn.open_sftp()
+        self._conn_count += 1
+
+        try:
+            yield self._sftp_conn
+        finally:
+            self._conn_count -= 1
+            if self._conn_count == 0 and self._ssh_conn is not None and self._sftp_conn is not None:
+                self._sftp_conn.close()
+                self._sftp_conn = None
+                self._ssh_conn.close()
+                self._ssh_conn = None
+
+    def get_conn_count(self) -> int:
+        """Return the number of active ``get_managed_conn`` contexts sharing the cached connection."""
+        return self._conn_count
 
     def describe_directory(self, path: str) -> dict[str, dict[str, str | int | None]]:
         """
@@ -309,13 +330,14 @@ class SFTPHook(SSHHook):
         if Path(local_full_path).exists():
             raise AirflowException(f"{local_full_path} already exists")
         Path(local_full_path).mkdir(parents=True)
-        files, dirs, _ = self.get_tree_map(remote_full_path)
-        for dir_path in dirs:
-            new_local_path = os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path))
-            Path(new_local_path).mkdir(parents=True, exist_ok=True)
-        for file_path in files:
-            new_local_path = os.path.join(local_full_path, os.path.relpath(file_path, remote_full_path))
-            self.retrieve_file(file_path, new_local_path, prefetch)
+        with self.get_managed_conn():
+            files, dirs, _ = self.get_tree_map(remote_full_path)
+            for dir_path in dirs:
+                new_local_path = os.path.join(local_full_path, os.path.relpath(dir_path, remote_full_path))
+                Path(new_local_path).mkdir(parents=True, exist_ok=True)
+            for file_path in files:
+                new_local_path = os.path.join(local_full_path, os.path.relpath(file_path, remote_full_path))
+                self.retrieve_file(file_path, new_local_path, prefetch)
 
     def store_directory(self, remote_full_path: str, local_full_path: str, confirm: bool = True) -> None:
         """
@@ -329,16 +351,21 @@ class SFTPHook(SSHHook):
         """
         if self.path_exists(remote_full_path):
             raise AirflowException(f"{remote_full_path} already exists")
-        self.create_directory(remote_full_path)
-        for root, dirs, files in os.walk(local_full_path):
-            for dir_name in dirs:
-                dir_path = os.path.join(root, dir_name)
-                new_remote_path = os.path.join(remote_full_path, os.path.relpath(dir_path, local_full_path))
-                self.create_directory(new_remote_path)
-            for file_name in files:
-                file_path = os.path.join(root, file_name)
-                new_remote_path = os.path.join(remote_full_path, os.path.relpath(file_path, local_full_path))
-                self.store_file(new_remote_path, file_path, confirm)
+        with self.get_managed_conn():
+            self.create_directory(remote_full_path)
+            for root, dirs, files in os.walk(local_full_path):
+                for dir_name in dirs:
+                    dir_path = os.path.join(root, dir_name)
+                    new_remote_path = os.path.join(
+                        remote_full_path, os.path.relpath(dir_path, local_full_path)
+                    )
+                    self.create_directory(new_remote_path)
+                for file_name in files:
+                    file_path = os.path.join(root, file_name)
+                    new_remote_path = os.path.join(
+                        remote_full_path, os.path.relpath(file_path, local_full_path)
+                    )
+                    self.store_file(new_remote_path, file_path, confirm)
 
     def get_mod_time(self, path: str) -> str:
         """
diff --git a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
index b1c0d13c369..38239546da8 100644
--- a/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
+++ b/providers/sftp/tests/unit/sftp/hooks/test_sftp.py
@@ -114,8 +114,14 @@ class TestSFTPHook:
         assert self.hook.conn is None
 
     def test_get_managed_conn(self):
-        with self.hook.get_managed_conn() as conn:
-            assert isinstance(conn, paramiko.SFTPClient)
+        with self.hook.get_managed_conn() as conn1:
+            assert isinstance(conn1, paramiko.SFTPClient)
+            with self.hook.get_managed_conn() as conn2:
+                assert conn1 is conn2
+                assert self.hook.get_conn_count() == 2
+            assert self.hook.get_conn_count() == 1
+        assert self.hook.get_conn_count() == 0
+        assert self.hook._sftp_conn is None
 
     @patch("airflow.providers.ssh.hooks.ssh.SSHHook.get_conn")
     def test_get_close_conn(self, mock_get_conn):