Copilot commented on code in PR #64465: URL: https://github.com/apache/airflow/pull/64465#discussion_r3025335348
########## providers/sftp/tests/conftest.py: ########## @@ -16,4 +16,36 @@ # under the License. from __future__ import annotations +from collections.abc import Generator +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, patch + +import pytest +from asyncssh import SFTPClient, SSHClientConnection + +if TYPE_CHECKING: + from airflow.providers.sftp.hooks.sftp import SFTPHookAsync + pytest_plugins = "tests_common.pytest_plugin" + + [email protected] +def sftp_hook_mocked() -> Generator[tuple[SFTPHookAsync, SFTPClient], Any, None]: + """ + Fixture that mocks SFTPHookAsync._get_conn with SSH + SFTP async mocks. + Returns a tuple (hook, sftp_client_mock) so tests can easily set readdir. + """ + from airflow.providers.sftp.hooks.sftp import SFTPHookAsync + + sftp_client_mock = AsyncMock(spec=SFTPClient) + sftp_client_mock.readdir.return_value = [] + + client_connection_mock = AsyncMock(spec=SSHClientConnection) + sftp_cm_mock = client_connection_mock.start_sftp_client.return_value + sftp_cm_mock.__aenter__ = AsyncMock(return_value=sftp_client_mock) + sftp_cm_mock.__aexit__ = AsyncMock(return_value=None) + + with patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") as mock_get_conn: + mock_get_conn.return_value.__aenter__.return_value = client_connection_mock + + yield SFTPHookAsync(), sftp_cm_mock Review Comment: This fixture is annotated as returning `(SFTPHookAsync, SFTPClient)` but it actually yields `(SFTPHookAsync(), sftp_cm_mock)` where the second element is the async context manager returned by `start_sftp_client()`, not the `SFTPClient` itself. This makes the typing/docs misleading and makes test usage confusing. Consider returning both the connection and the actual `SFTPClient` mock (or rename/retag the tuple appropriately). 
########## providers/sftp/src/airflow/providers/sftp/hooks/sftp.py: ########## @@ -789,24 +793,117 @@ def _get_value(self_val, conn_val, default=None): ssh_client_conn = await asyncssh.connect(**conn_config) return ssh_client_conn - async def list_directory(self, path: str = "") -> list[str] | None: # type: ignore[return] - """Return a list of files on the SFTP server at the provided path.""" + async def retrieve_file( + self, + remote_full_path: str, + local_full_path: str | BytesIO, + encoding: str = "utf-8", + chunk_size: int = CHUNK_SIZE, + ) -> None: + """ + Transfer the remote file to a local location asynchronously. + + If local_full_path is a string path, the file will be put at that location. + If it is a BytesIO or file-like object, the file will be streamed into it. + + :param remote_full_path: Full path to the remote file. + :param local_full_path: Full path to the local file or a file-like buffer. + :param encoding: Encoding to use for reading the remote file (default: "utf-8"). + :param chunk_size: Size of chunks to read at a time (default: 64KB). + """ + async with await self._get_conn() as ssh_conn: + async with ssh_conn.start_sftp_client() as sftp: + async with sftp.open(remote_full_path, encoding=encoding) as remote_file: + if isinstance(local_full_path, BytesIO): + while True: + chunk = await remote_file.read(chunk_size) + if not chunk: + break + local_full_path.write(chunk.encode(encoding)) + local_full_path.seek(0) + else: + async with aiofiles.open(local_full_path, "wb") as f: + while True: + chunk = await remote_file.read(chunk_size) + if not chunk: + break + await f.write(chunk) Review Comment: `retrieve_file()` opens the remote file with `encoding=...`, which makes `remote_file.read()` return `str`, but the implementation writes chunks to a binary local file (`aiofiles.open(..., "wb")`) and also sometimes treats chunks as `bytes`. This will raise type errors depending on the backend. 
Consider reading bytes (open the remote file in binary mode / without `encoding` and write bytes), or, if you want text mode, open the local file in text mode and avoid mixing `str`/`bytes` (and avoid `.encode()` when chunks are already bytes). ########## providers/sftp/tests/unit/sftp/pools/test_sftp.py: ########## @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import pytest +from airflow.providers.sftp.pools.sftp import SFTPClientPool + + [email protected] +class TestSFTPClientPool: + @pytest.mark.asyncio + async def test_acquire_and_release(self, sftp_hook_mocked): + async with SFTPClientPool("test_conn", pool_size=2) as pool: + ssh, sftp = await pool.acquire() + assert ssh is not None + assert sftp is not None + + await pool.release((ssh, sftp)) + ssh2, sftp2 = await pool.acquire() Review Comment: Because `SFTPClientPool` is a process-wide singleton (`_instances`), using the same conn_id (`"test_conn"`) across multiple tests can cause state leakage/order dependence (e.g. `_closed`, `_idle`, `_in_use`). Consider resetting `SFTPClientPool._instances` in a fixture/teardown, or using unique conn_ids per test to keep tests isolated. 
########## providers/sftp/src/airflow/providers/sftp/hooks/sftp.py: ########## @@ -789,24 +793,117 @@ def _get_value(self_val, conn_val, default=None): ssh_client_conn = await asyncssh.connect(**conn_config) return ssh_client_conn - async def list_directory(self, path: str = "") -> list[str] | None: # type: ignore[return] - """Return a list of files on the SFTP server at the provided path.""" + async def retrieve_file( + self, + remote_full_path: str, + local_full_path: str | BytesIO, + encoding: str = "utf-8", + chunk_size: int = CHUNK_SIZE, + ) -> None: + """ + Transfer the remote file to a local location asynchronously. + + If local_full_path is a string path, the file will be put at that location. + If it is a BytesIO or file-like object, the file will be streamed into it. + + :param remote_full_path: Full path to the remote file. + :param local_full_path: Full path to the local file or a file-like buffer. + :param encoding: Encoding to use for reading the remote file (default: "utf-8"). + :param chunk_size: Size of chunks to read at a time (default: 64KB). + """ + async with await self._get_conn() as ssh_conn: + async with ssh_conn.start_sftp_client() as sftp: + async with sftp.open(remote_full_path, encoding=encoding) as remote_file: + if isinstance(local_full_path, BytesIO): + while True: + chunk = await remote_file.read(chunk_size) + if not chunk: + break + local_full_path.write(chunk.encode(encoding)) + local_full_path.seek(0) + else: + async with aiofiles.open(local_full_path, "wb") as f: + while True: + chunk = await remote_file.read(chunk_size) + if not chunk: + break + await f.write(chunk) Review Comment: The `retrieve_file()` docstring says `local_full_path` can be a "BytesIO or file-like buffer", but the type annotation and implementation only handle `BytesIO` (and otherwise assume a filesystem path). Either update the docstring/type hints to match the supported inputs, or add support for generic writable binary streams (e.g. 
objects with a `.write()` method). ```suggestion local_full_path: str | os.PathLike[str] | IO[bytes], encoding: str = "utf-8", chunk_size: int = CHUNK_SIZE, ) -> None: """ Transfer the remote file to a local location asynchronously. If local_full_path is a string or PathLike path, the file will be put at that location. If it is a BytesIO or other binary file-like object, the file will be streamed into it. :param remote_full_path: Full path to the remote file. :param local_full_path: Full path to the local file or a binary file-like buffer. :param encoding: Encoding to use for reading the remote file (default: "utf-8"). :param chunk_size: Size of chunks to read at a time (default: 64KB). """ async with await self._get_conn() as ssh_conn: async with ssh_conn.start_sftp_client() as sftp: async with sftp.open(remote_full_path, encoding=encoding) as remote_file: if isinstance(local_full_path, (str, os.PathLike)): async with aiofiles.open(local_full_path, "wb") as f: while True: chunk = await remote_file.read(chunk_size) if not chunk: break await f.write(chunk.encode(encoding)) else: while True: chunk = await remote_file.read(chunk_size) if not chunk: break local_full_path.write(chunk.encode(encoding)) if hasattr(local_full_path, "seek"): local_full_path.seek(0) ``` ########## providers/sftp/tests/unit/sftp/hooks/test_sftp.py: ########## @@ -947,86 +948,177 @@ async def test_init_argument_not_ignored(self, mock_get_connection, mock_connect ), ] - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_list_directory_path_does_not_exist(self, mock_hook_get_conn): + async def test_list_directory_path_does_not_exist(self, sftp_hook_mocked): """ Assert that AirflowException is raised when path does not exist on SFTP server """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - - hook = SFTPHookAsync() + hook, sftp_cm_mock = sftp_hook_mocked - expected_files = None files = await 
hook.list_directory(path="/path/does_not/exist/") - assert files == expected_files - mock_hook_get_conn.return_value.__aexit__.assert_called() + assert not files + sftp_cm_mock.__aexit__.assert_awaited() - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_read_directory_path_does_not_exist(self, mock_hook_get_conn): + async def test_read_directory_path_does_not_exist(self, sftp_hook_mocked): """ Assert that AirflowException is raised when path does not exist on SFTP server """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked - expected_files = None files = await hook.read_directory(path="/path/does_not/exist/") - assert files == expected_files - mock_hook_get_conn.return_value.__aexit__.assert_called() + assert not files + sftp_client_mock.__aexit__.assert_awaited() - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_list_directory_path_has_files(self, mock_hook_get_conn): + async def test_list_directory_path_has_files(self, sftp_hook_mocked): """ Assert that file list is returned when path exists on SFTP server """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client_mock.__aenter__.return_value.readdir.return_value = [ + Mock(spec=SFTPName, filename="..", attrs=Mock(permissions=0)), + Mock(spec=SFTPName, filename=".", attrs=Mock(permissions=0)), + Mock(spec=SFTPName, filename="file", attrs=Mock(permissions=0)), + ] - expected_files = ["..", ".", "file"] files = await hook.list_directory(path="/path/exists/") - assert sorted(files) == sorted(expected_files) - mock_hook_get_conn.return_value.__aexit__.assert_called() + assert sorted(files) == sorted(["/path/exists/file"]) + sftp_client_mock.__aexit__.assert_awaited() - 
@patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_get_file_by_pattern_with_match(self, mock_hook_get_conn): + async def test_get_file_by_pattern_with_match(self, sftp_hook_mocked): """ Assert that filename is returned when file pattern is matched on SFTP server """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client_mock.__aenter__.return_value.readdir.return_value = [ + Mock(spec=SFTPName, filename="..", attrs=Mock(permissions=0)), + Mock(spec=SFTPName, filename=".", attrs=Mock(permissions=0)), + Mock(spec=SFTPName, filename="file", attrs=Mock(permissions=0)), + ] files = await hook.get_files_and_attrs_by_pattern(path="/path/exists/", fnmatch_pattern="file") assert len(files) == 1 assert files[0].filename == "file" - mock_hook_get_conn.return_value.__aexit__.assert_called() + sftp_client_mock.__aexit__.assert_awaited() @pytest.mark.asyncio - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") - async def test_get_mod_time(self, mock_hook_get_conn): + async def test_get_mod_time(self, sftp_hook_mocked): """ Assert that file attribute and return the modified time of the file """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked + + mtime = 1667302566 # This is a valid Unix timestamp + expected = datetime.datetime.fromtimestamp(mtime).strftime("%Y%m%d%H%M%S") + sftp_client_mock.__aenter__.return_value.stat.return_value = Mock(spec=SFTPAttrs, mtime=mtime) mod_time = await hook.get_mod_time("/path/exists/file") - expected_value = datetime.datetime.fromtimestamp(1667302566).strftime("%Y%m%d%H%M%S") - assert mod_time == expected_value + assert mod_time == expected @pytest.mark.asyncio - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") - async def test_get_mod_time_exception(self, 
mock_hook_get_conn): + async def test_get_mod_time_exception(self, sftp_hook_mocked): """ Assert that get_mod_time raise exception when file does not exist """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client_mock.__aenter__.return_value.stat.side_effect = SFTPNoSuchFile( + reason="File does not exist" + ) with pytest.raises(AirflowException) as exc: await hook.get_mod_time("/path/does_not/exist/") assert str(exc.value) == "No files matching" + + @pytest.mark.asyncio + async def test_mkdir_creates_directory(self, sftp_hook_mocked): + """ + Assert that mkdir calls makedirs on the SFTP client + """ + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client = sftp_client_mock.__aenter__.return_value + sftp_client.makedirs = AsyncMock() + + await hook.mkdir("/remote/newdir") + sftp_client.makedirs.assert_awaited_once_with("/remote/newdir") + sftp_client_mock.__aexit__.assert_awaited() + + @pytest.mark.asyncio + @patch("aiofiles.open", new_callable=MagicMock) + async def test_retrieve_file_to_path(self, mock_aiofiles_open, sftp_hook_mocked): + """ + Assert that retrieve_file writes to a local file using aiofiles + """ + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client = sftp_client_mock.__aenter__.return_value + mock_remote_file = AsyncMock() + mock_remote_file.read = AsyncMock(side_effect=[b"abc", b"", StopAsyncIteration]) + sftp_client.open.return_value.__aenter__.return_value = mock_remote_file + mock_file = AsyncMock() + mock_aiofiles_open.return_value.__aenter__.return_value = mock_file + + await hook.retrieve_file("/remote/file", "/local/file") + sftp_client.open.assert_called_once_with("/remote/file", encoding="utf-8") + mock_file.write.assert_awaited() + sftp_client_mock.__aexit__.assert_awaited() + + @pytest.mark.asyncio + async def test_retrieve_file_to_bytesio(self, sftp_hook_mocked): + """ + Assert that retrieve_file writes to a BytesIO 
buffer + """ + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client = sftp_client_mock.__aenter__.return_value + mock_remote_file = AsyncMock() + mock_remote_file.read = AsyncMock(side_effect=["abc", ""]) + sftp_client.open.return_value.__aenter__.return_value = mock_remote_file + buf = BytesIO() + + await hook.retrieve_file("/remote/file", buf) + assert buf.getvalue() == b"abc" + sftp_client.open.assert_called_once_with("/remote/file", encoding="utf-8") + sftp_client_mock.__aexit__.assert_awaited() Review Comment: The `retrieve_file` tests mock `remote_file.read()` to return `bytes` in one case and `str` in another, but the implementation currently opens the remote file with `encoding="utf-8"` which will typically make reads return `str`. Once the production code is fixed to consistently read bytes or text, these tests should be updated to match that contract (and assert exact written bytes/strings). ########## providers/sftp/src/airflow/providers/sftp/pools/sftp.py: ########## @@ -0,0 +1,183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager, suppress +from threading import Lock + +import asyncssh + +from airflow.configuration import conf +from airflow.providers.sftp.hooks.sftp import SFTPHookAsync +from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin + + +class SFTPClientPool(LoggingMixin): + """Lazy Thread-safe and Async-safe Singleton SFTP pool that keeps SSH and SFTP clients alive until exit, and limits concurrent usage to pool_size.""" + + _instances: dict[str, SFTPClientPool] = {} + _lock = Lock() + + def __new__(cls, sftp_conn_id: str, pool_size: int = None): + with cls._lock: + if sftp_conn_id not in cls._instances: + instance = super().__new__(cls) + instance._pre_init(sftp_conn_id, pool_size) + cls._instances[sftp_conn_id] = instance Review Comment: `SFTPClientPool` is a singleton keyed only by `sftp_conn_id`. If the pool is first constructed with one `pool_size` and later with a different `pool_size`, the later value is silently ignored, which can lead to very surprising runtime behavior. Consider including `pool_size` in the cache key, or validate and raise when a different `pool_size` is requested for an existing instance. ```suggestion cls._instances[sftp_conn_id] = instance else: # Validate that subsequent constructions for the same sftp_conn_id # do not request a different pool_size, which would otherwise be # silently ignored due to the singleton behavior. instance = cls._instances[sftp_conn_id] requested_pool_size = pool_size or conf.getint("core", "parallelism") if instance.pool_size != requested_pool_size: raise ValueError( f"SFTPClientPool for sftp_conn_id '{sftp_conn_id}' has already been " f"initialised with pool_size={instance.pool_size}, but a different " f"pool_size={requested_pool_size} was requested." 
) ``` ########## providers/sftp/tests/unit/sftp/hooks/test_sftp.py: ########## @@ -947,86 +948,177 @@ async def test_init_argument_not_ignored(self, mock_get_connection, mock_connect ), ] - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_list_directory_path_does_not_exist(self, mock_hook_get_conn): + async def test_list_directory_path_does_not_exist(self, sftp_hook_mocked): """ Assert that AirflowException is raised when path does not exist on SFTP server """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - - hook = SFTPHookAsync() + hook, sftp_cm_mock = sftp_hook_mocked - expected_files = None files = await hook.list_directory(path="/path/does_not/exist/") - assert files == expected_files - mock_hook_get_conn.return_value.__aexit__.assert_called() + assert not files + sftp_cm_mock.__aexit__.assert_awaited() - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_read_directory_path_does_not_exist(self, mock_hook_get_conn): + async def test_read_directory_path_does_not_exist(self, sftp_hook_mocked): """ Assert that AirflowException is raised when path does not exist on SFTP server """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked - expected_files = None files = await hook.read_directory(path="/path/does_not/exist/") - assert files == expected_files - mock_hook_get_conn.return_value.__aexit__.assert_called() + assert not files + sftp_client_mock.__aexit__.assert_awaited() - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_list_directory_path_has_files(self, mock_hook_get_conn): + async def test_list_directory_path_has_files(self, sftp_hook_mocked): """ Assert that file list is returned when path exists on SFTP server """ - 
mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client_mock.__aenter__.return_value.readdir.return_value = [ + Mock(spec=SFTPName, filename="..", attrs=Mock(permissions=0)), + Mock(spec=SFTPName, filename=".", attrs=Mock(permissions=0)), + Mock(spec=SFTPName, filename="file", attrs=Mock(permissions=0)), + ] - expected_files = ["..", ".", "file"] files = await hook.list_directory(path="/path/exists/") - assert sorted(files) == sorted(expected_files) - mock_hook_get_conn.return_value.__aexit__.assert_called() + assert sorted(files) == sorted(["/path/exists/file"]) + sftp_client_mock.__aexit__.assert_awaited() - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_get_file_by_pattern_with_match(self, mock_hook_get_conn): + async def test_get_file_by_pattern_with_match(self, sftp_hook_mocked): """ Assert that filename is returned when file pattern is matched on SFTP server """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client_mock.__aenter__.return_value.readdir.return_value = [ + Mock(spec=SFTPName, filename="..", attrs=Mock(permissions=0)), + Mock(spec=SFTPName, filename=".", attrs=Mock(permissions=0)), + Mock(spec=SFTPName, filename="file", attrs=Mock(permissions=0)), + ] files = await hook.get_files_and_attrs_by_pattern(path="/path/exists/", fnmatch_pattern="file") assert len(files) == 1 assert files[0].filename == "file" - mock_hook_get_conn.return_value.__aexit__.assert_called() + sftp_client_mock.__aexit__.assert_awaited() @pytest.mark.asyncio - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") - async def test_get_mod_time(self, mock_hook_get_conn): + async def test_get_mod_time(self, sftp_hook_mocked): """ Assert that file attribute and return the modified time of the file """ - 
mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked + + mtime = 1667302566 # This is a valid Unix timestamp + expected = datetime.datetime.fromtimestamp(mtime).strftime("%Y%m%d%H%M%S") + sftp_client_mock.__aenter__.return_value.stat.return_value = Mock(spec=SFTPAttrs, mtime=mtime) mod_time = await hook.get_mod_time("/path/exists/file") - expected_value = datetime.datetime.fromtimestamp(1667302566).strftime("%Y%m%d%H%M%S") - assert mod_time == expected_value + assert mod_time == expected @pytest.mark.asyncio - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") - async def test_get_mod_time_exception(self, mock_hook_get_conn): + async def test_get_mod_time_exception(self, sftp_hook_mocked): """ Assert that get_mod_time raise exception when file does not exist """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - hook = SFTPHookAsync() + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client_mock.__aenter__.return_value.stat.side_effect = SFTPNoSuchFile( + reason="File does not exist" + ) with pytest.raises(AirflowException) as exc: await hook.get_mod_time("/path/does_not/exist/") assert str(exc.value) == "No files matching" + + @pytest.mark.asyncio + async def test_mkdir_creates_directory(self, sftp_hook_mocked): + """ + Assert that mkdir calls makedirs on the SFTP client + """ + hook, sftp_client_mock = sftp_hook_mocked + + sftp_client = sftp_client_mock.__aenter__.return_value + sftp_client.makedirs = AsyncMock() + + await hook.mkdir("/remote/newdir") + sftp_client.makedirs.assert_awaited_once_with("/remote/newdir") + sftp_client_mock.__aexit__.assert_awaited() + + @pytest.mark.asyncio + @patch("aiofiles.open", new_callable=MagicMock) + async def test_retrieve_file_to_path(self, mock_aiofiles_open, sftp_hook_mocked): + """ + Assert that retrieve_file writes to a local file using aiofiles + """ + hook, sftp_client_mock = 
sftp_hook_mocked + + sftp_client = sftp_client_mock.__aenter__.return_value + mock_remote_file = AsyncMock() + mock_remote_file.read = AsyncMock(side_effect=[b"abc", b"", StopAsyncIteration]) + sftp_client.open.return_value.__aenter__.return_value = mock_remote_file + mock_file = AsyncMock() + mock_aiofiles_open.return_value.__aenter__.return_value = mock_file + Review Comment: `aiofiles.open` is used as an async context manager in `retrieve_file()`, but this test patches it with `new_callable=MagicMock`, which won't provide awaitable `__aenter__/__aexit__` methods and will fail under `async with`. Patch `aiofiles.open` to return an object whose `__aenter__`/`__aexit__` are `AsyncMock`s (see patterns used in e.g. `providers/databricks/tests/...`). ########## providers/sftp/tests/unit/sftp/hooks/test_sftp.py: ########## @@ -947,86 +948,177 @@ async def test_init_argument_not_ignored(self, mock_get_connection, mock_connect ), ] - @patch("airflow.providers.sftp.hooks.sftp.SFTPHookAsync._get_conn") @pytest.mark.asyncio - async def test_list_directory_path_does_not_exist(self, mock_hook_get_conn): + async def test_list_directory_path_does_not_exist(self, sftp_hook_mocked): """ Assert that AirflowException is raised when path does not exist on SFTP server """ - mock_hook_get_conn.return_value.__aenter__.return_value = MockSSHClient() - - hook = SFTPHookAsync() + hook, sftp_cm_mock = sftp_hook_mocked - expected_files = None files = await hook.list_directory(path="/path/does_not/exist/") - assert files == expected_files - mock_hook_get_conn.return_value.__aexit__.assert_called() + assert not files + sftp_cm_mock.__aexit__.assert_awaited() Review Comment: `test_list_directory_path_does_not_exist` currently relies on the default `readdir.return_value = []`, so it doesn't verify the hook's behavior when the directory truly does not exist (i.e. `asyncssh.SFTPNoSuchFile`). 
Consider setting `readdir.side_effect = asyncssh.SFTPNoSuchFile(...)` and asserting the expected return (`None`) or raised exception, so the test actually covers the error path. ########## providers/sftp/src/airflow/providers/sftp/hooks/sftp.py: ########## @@ -789,24 +793,117 @@ def _get_value(self_val, conn_val, default=None): ssh_client_conn = await asyncssh.connect(**conn_config) return ssh_client_conn - async def list_directory(self, path: str = "") -> list[str] | None: # type: ignore[return] - """Return a list of files on the SFTP server at the provided path.""" + async def retrieve_file( + self, + remote_full_path: str, + local_full_path: str | BytesIO, + encoding: str = "utf-8", + chunk_size: int = CHUNK_SIZE, + ) -> None: + """ + Transfer the remote file to a local location asynchronously. + + If local_full_path is a string path, the file will be put at that location. + If it is a BytesIO or file-like object, the file will be streamed into it. + + :param remote_full_path: Full path to the remote file. + :param local_full_path: Full path to the local file or a file-like buffer. + :param encoding: Encoding to use for reading the remote file (default: "utf-8"). + :param chunk_size: Size of chunks to read at a time (default: 64KB). + """ + async with await self._get_conn() as ssh_conn: + async with ssh_conn.start_sftp_client() as sftp: + async with sftp.open(remote_full_path, encoding=encoding) as remote_file: + if isinstance(local_full_path, BytesIO): + while True: + chunk = await remote_file.read(chunk_size) + if not chunk: + break + local_full_path.write(chunk.encode(encoding)) + local_full_path.seek(0) + else: + async with aiofiles.open(local_full_path, "wb") as f: + while True: + chunk = await remote_file.read(chunk_size) + if not chunk: + break + await f.write(chunk) + + async def store_file(self, remote_full_path: str, local_full_path: str | bytes | BytesIO) -> None: + """ + Transfer a local file to the remote location. 
+ + If local_full_path_or_buffer is a string path, the file will be read + from that location. + + :param remote_full_path: full path to the remote file + :param local_full_path: full path to the local file or a file-like buffer + """ + async with await self._get_conn() as ssh_conn: + async with ssh_conn.start_sftp_client() as sftp: + if isinstance(local_full_path, bytes): + local_full_path = BytesIO(local_full_path) + + if isinstance(local_full_path, BytesIO): + with suppress(asyncssh.SFTPFailure): + remote_path = PurePosixPath(remote_full_path) + await sftp.makedirs(str(remote_path.parent)) + + async with sftp.open(remote_full_path, "wb") as f: + local_full_path.seek(0) + data = local_full_path.read() + await f.write(data) + else: + await sftp.put(str(local_full_path), remote_full_path) + + async def mkdir(self, path: str) -> None: + """ + Create a directory on the remote system asynchronously. + + The default permissions are determined by the server. Parent directories are created as needed. + + :param path: Full path to the remote directory to create. + """ + async with await self._get_conn() as ssh_conn: + async with ssh_conn.start_sftp_client() as sftp: + await sftp.makedirs(path) + + async def list_directory(self, path: str) -> list[str] | None: + """ + List files in a directory on the remote system asynchronously. + + Recursively lists all files under the given directory path. + + :param path: Full path to the remote directory to list. + :return: List of file paths found under the directory, or None if the directory does not exist. 
+ """ async with await self._get_conn() as ssh_conn: - sftp_client = await ssh_conn.start_sftp_client() - try: - files = await sftp_client.listdir(path) - return sorted(files) - except asyncssh.SFTPNoSuchFile: - return None + async with ssh_conn.start_sftp_client() as sftp: + + async def walk(dir_path: str): + results = [] + files = await sftp.readdir(dir_path) + + for file in files: + if file.filename not in {".", ".."}: + file_path = posixpath.join(dir_path, file.filename) + if stat.S_ISDIR(file.attrs.permissions): + results.extend(await walk(file_path)) + else: + results.append(file_path) + + return results + + return await walk(path) Review Comment: `list_directory()` now recursively walks subdirectories and returns full paths. This is a breaking behavior change compared to the previous async implementation (which used `listdir()` to list a single directory) and also differs from the sync `SFTPHook.list_directory()` which returns filenames for one level. If recursion/full-paths are desired, consider adding an explicit `recursive` flag or a separate method to preserve the existing API contract. ########## providers/sftp/src/airflow/providers/sftp/pools/sftp.py: ########## @@ -0,0 +1,183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager, suppress +from threading import Lock + +import asyncssh + +from airflow.configuration import conf +from airflow.providers.sftp.hooks.sftp import SFTPHookAsync +from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin Review Comment: This provider code imports `LoggingMixin` from an internal task-sdk module (`airflow.sdk.definitions._internal...`). Providers generally use the public `airflow.utils.log.logging_mixin.LoggingMixin` (e.g. `providers/amazon/.../base_aws.py:61`), and relying on `_internal` paths is brittle. Please switch to the public import to avoid breakage across Airflow versions. ```suggestion from airflow.utils.log.logging_mixin import LoggingMixin ``` ########## providers/sftp/src/airflow/providers/sftp/pools/sftp.py: ########## @@ -0,0 +1,183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager, suppress +from threading import Lock + +import asyncssh + +from airflow.configuration import conf +from airflow.providers.sftp.hooks.sftp import SFTPHookAsync +from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin + + +class SFTPClientPool(LoggingMixin): + """Lazy Thread-safe and Async-safe Singleton SFTP pool that keeps SSH and SFTP clients alive until exit, and limits concurrent usage to pool_size.""" + + _instances: dict[str, SFTPClientPool] = {} + _lock = Lock() + + def __new__(cls, sftp_conn_id: str, pool_size: int = None): + with cls._lock: + if sftp_conn_id not in cls._instances: + instance = super().__new__(cls) + instance._pre_init(sftp_conn_id, pool_size) + cls._instances[sftp_conn_id] = instance + return cls._instances[sftp_conn_id] + + def __init__(self, sftp_conn_id: str, pool_size: int = None): + # Prevent parent __init__ argument errors + pass + + def _pre_init(self, sftp_conn_id: str, pool_size: int): + """Synchronous initialization for the Singleton structure.""" + LoggingMixin.__init__(self) + self.sftp_conn_id = sftp_conn_id + self.pool_size = pool_size or conf.getint("core", "parallelism") + self._idle: asyncio.LifoQueue[tuple[asyncssh.SSHClientConnection, asyncssh.SFTPClient]] = ( + asyncio.LifoQueue() + ) + self._in_use: set[tuple[asyncssh.SSHClientConnection, asyncssh.SFTPClient]] = set() + self._semaphore = asyncio.Semaphore(self.pool_size) + self._init_lock = asyncio.Lock() + self._initialized = False + self._closed = False + self.log.info("SFTPClientPool initialised...") + + async def _ensure_initialized(self): + """Ensure pool is usable (also handles re-opening after close).""" + if self._initialized and not self._closed: + return + + async with self._init_lock: + if not self._initialized or self._closed: + self.log.info("Initializing / resetting SFTPClientPool for '%s'", self.sftp_conn_id) + self._idle = 
asyncio.LifoQueue() + self._in_use.clear() + self._semaphore = asyncio.Semaphore(self.pool_size) + self._closed = False + self._initialized = True + + async def _create_connection( + self, + ) -> tuple[asyncssh.SSHClientConnection, asyncssh.SFTPClient]: + ssh_conn = await SFTPHookAsync(sftp_conn_id=self.sftp_conn_id)._get_conn() + sftp = await ssh_conn.start_sftp_client() + self.log.info("Created new SFTP connection for sftp_conn_id '%s'", self.sftp_conn_id) + return ssh_conn, sftp + + async def acquire(self): + await self._ensure_initialized() + + if self._closed: + raise RuntimeError("Cannot acquire from a closed SFTPClientPool") + + self.log.debug("Acquiring SFTP connection for '%s'", self.sftp_conn_id) + + await self._semaphore.acquire() + + try: + try: + pair = self._idle.get_nowait() + except asyncio.QueueEmpty: + pair = await self._create_connection() + + self._in_use.add(pair) + return pair + except Exception: + self._semaphore.release() + raise + + async def release(self, pair): + if pair not in self._in_use: + self.log.warning("Attempted to release unknown or already released connection") + return + + self._in_use.discard(pair) + + if self._closed: + ssh, sftp = pair + with suppress(Exception): + sftp.exit() + with suppress(Exception): + ssh.close() + else: + await self._idle.put(pair) + + self.log.debug("Releasing SFTP connection for '%s'", self.sftp_conn_id) + self._semaphore.release() + + @asynccontextmanager + async def get_sftp_client(self): + await self._ensure_initialized() + pair = None + try: + pair = await self.acquire() + ssh, sftp = pair + yield sftp + except BaseException as e: + self.log.warning("Dropping faulty connection for '%s': %s", self.sftp_conn_id, e) + if pair: + ssh, sftp = pair + self._in_use.discard(pair) + with suppress(Exception): + sftp.exit() + with suppress(Exception): + ssh.close() + self._semaphore.release() + raise + else: + await self.release(pair) + + async def close(self): + """Gracefully shutdown all connections in 
the pool.""" + async with self._init_lock: + if self._closed: + return + + self._closed = True + + self.log.info("Closing all SFTP connections for '%s'", self.sftp_conn_id) + + while not self._idle.empty(): + ssh, sftp = await self._idle.get() + with suppress(Exception): + sftp.exit() + with suppress(Exception): + ssh.close() + + for pair in list(self._in_use): + ssh, sftp = pair + with suppress(Exception): + sftp.exit() + with suppress(Exception): + ssh.close() + self._in_use.discard(pair) + + if self._in_use: + self.log.warning("Pool closed with %d active connections", len(self._in_use)) + Review Comment: In `close()`, `_in_use` is fully drained inside the loop, so the subsequent `if self._in_use:` warning will never trigger. If you want to warn when closing with active connections, capture `len(self._in_use)` before the loop (or warn based on the list you iterate) and log that value. ########## providers/sftp/src/airflow/providers/sftp/hooks/sftp.py: ########## @@ -789,24 +793,117 @@ def _get_value(self_val, conn_val, default=None): ssh_client_conn = await asyncssh.connect(**conn_config) return ssh_client_conn - async def list_directory(self, path: str = "") -> list[str] | None: # type: ignore[return] - """Return a list of files on the SFTP server at the provided path.""" + async def retrieve_file( + self, + remote_full_path: str, + local_full_path: str | BytesIO, + encoding: str = "utf-8", + chunk_size: int = CHUNK_SIZE, + ) -> None: + """ + Transfer the remote file to a local location asynchronously. + + If local_full_path is a string path, the file will be put at that location. + If it is a BytesIO or file-like object, the file will be streamed into it. + + :param remote_full_path: Full path to the remote file. + :param local_full_path: Full path to the local file or a file-like buffer. + :param encoding: Encoding to use for reading the remote file (default: "utf-8"). + :param chunk_size: Size of chunks to read at a time (default: 64KB). 
+ """ + async with await self._get_conn() as ssh_conn: + async with ssh_conn.start_sftp_client() as sftp: + async with sftp.open(remote_full_path, encoding=encoding) as remote_file: + if isinstance(local_full_path, BytesIO): + while True: + chunk = await remote_file.read(chunk_size) + if not chunk: + break + local_full_path.write(chunk.encode(encoding)) + local_full_path.seek(0) + else: + async with aiofiles.open(local_full_path, "wb") as f: + while True: + chunk = await remote_file.read(chunk_size) + if not chunk: + break + await f.write(chunk) + + async def store_file(self, remote_full_path: str, local_full_path: str | bytes | BytesIO) -> None: + """ + Transfer a local file to the remote location. + + If local_full_path_or_buffer is a string path, the file will be read + from that location. + + :param remote_full_path: full path to the remote file + :param local_full_path: full path to the local file or a file-like buffer + """ + async with await self._get_conn() as ssh_conn: + async with ssh_conn.start_sftp_client() as sftp: + if isinstance(local_full_path, bytes): + local_full_path = BytesIO(local_full_path) + + if isinstance(local_full_path, BytesIO): + with suppress(asyncssh.SFTPFailure): + remote_path = PurePosixPath(remote_full_path) + await sftp.makedirs(str(remote_path.parent)) + + async with sftp.open(remote_full_path, "wb") as f: + local_full_path.seek(0) + data = local_full_path.read() + await f.write(data) + else: + await sftp.put(str(local_full_path), remote_full_path) + + async def mkdir(self, path: str) -> None: + """ + Create a directory on the remote system asynchronously. + + The default permissions are determined by the server. Parent directories are created as needed. + + :param path: Full path to the remote directory to create. 
+ """ + async with await self._get_conn() as ssh_conn: + async with ssh_conn.start_sftp_client() as sftp: + await sftp.makedirs(path) + + async def list_directory(self, path: str) -> list[str] | None: + """ + List files in a directory on the remote system asynchronously. + + Recursively lists all files under the given directory path. + + :param path: Full path to the remote directory to list. + :return: List of file paths found under the directory, or None if the directory does not exist. + """ async with await self._get_conn() as ssh_conn: - sftp_client = await ssh_conn.start_sftp_client() - try: - files = await sftp_client.listdir(path) - return sorted(files) - except asyncssh.SFTPNoSuchFile: - return None + async with ssh_conn.start_sftp_client() as sftp: + + async def walk(dir_path: str): + results = [] + files = await sftp.readdir(dir_path) + + for file in files: + if file.filename not in {".", ".."}: + file_path = posixpath.join(dir_path, file.filename) + if stat.S_ISDIR(file.attrs.permissions): + results.extend(await walk(file_path)) + else: + results.append(file_path) + + return results + + return await walk(path) Review Comment: `list_directory()` docstring says it returns `None` when the directory does not exist, but the implementation doesn't catch `asyncssh.SFTPNoSuchFile` (unlike `read_directory()` below) and will raise instead. If the intent is to keep the previous behavior, wrap the initial `readdir()` (or the whole walk) in a `try/except asyncssh.SFTPNoSuchFile` and return `None`. ```suggestion results: list[str] = [] try: files = await sftp.readdir(dir_path) except asyncssh.SFTPNoSuchFile: # If the top-level directory does not exist, signal this to the caller. if dir_path == path: return None # For subdirectories, preserve the previous behavior and propagate the error. 
raise for file in files: if file.filename not in {".", ".."}: file_path = posixpath.join(dir_path, file.filename) if stat.S_ISDIR(file.attrs.permissions): results.extend(await walk(file_path)) # type: ignore[arg-type] else: results.append(file_path) return results result = await walk(path) if result is None: return None return result ``` ########## providers/sftp/tests/unit/sftp/pools/test_sftp.py: ########## @@ -0,0 +1,72 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+from __future__ import annotations + +import pytest +from airflow.providers.sftp.pools.sftp import SFTPClientPool + + [email protected] +class TestSFTPClientPool: + @pytest.mark.asyncio + async def test_acquire_and_release(self, sftp_hook_mocked): + async with SFTPClientPool("test_conn", pool_size=2) as pool: + ssh, sftp = await pool.acquire() + assert ssh is not None + assert sftp is not None + + await pool.release((ssh, sftp)) + ssh2, sftp2 = await pool.acquire() + assert ssh2 is not None + assert sftp2 is not None + + @pytest.mark.asyncio + async def test_get_sftp_client_context_manager(self, sftp_hook_mocked): + async with SFTPClientPool("test_conn", pool_size=2) as pool: + assert pool is not None + Review Comment: `test_get_sftp_client_context_manager` currently only asserts `pool is not None`, which doesn't exercise the `get_sftp_client()` context manager behavior. Consider updating the test to `async with pool.get_sftp_client() as sftp:` and assert the yielded client is usable, and that it is returned to the pool afterward. ```suggestion async with SFTPClientPool("test_conn", pool_size=1) as pool: async with pool.get_sftp_client() as sftp: assert sftp is not None # basic usability check: SFTP client should expose some attributes assert hasattr(sftp, "listdir") or hasattr(sftp, "put") or hasattr(sftp, "get") # After the context manager exits, the client should be returned to the pool ssh2, sftp2 = await pool.acquire() assert ssh2 is not None assert sftp2 is not None ``` ########## providers/sftp/src/airflow/providers/sftp/pools/sftp.py: ########## @@ -0,0 +1,183 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +from __future__ import annotations + +import asyncio +from contextlib import asynccontextmanager, suppress +from threading import Lock + +import asyncssh + +from airflow.configuration import conf +from airflow.providers.sftp.hooks.sftp import SFTPHookAsync +from airflow.sdk.definitions._internal.logging_mixin import LoggingMixin + + +class SFTPClientPool(LoggingMixin): + """Lazy Thread-safe and Async-safe Singleton SFTP pool that keeps SSH and SFTP clients alive until exit, and limits concurrent usage to pool_size.""" + Review Comment: The class docstring claims the pool is "Thread-safe and Async-safe", but after construction it mutates `_idle`, `_in_use`, and uses asyncio primitives without any thread synchronization. Either remove the thread-safety claim or add proper thread-safe boundaries (or clearly document that the pool must be used within a single event loop/thread). ```suggestion """Lazy async-safe Singleton SFTP pool. This pool keeps SSH and SFTP clients alive until exit and limits concurrent usage to ``pool_size``. It is intended to be used from a single asyncio event loop/thread. Threading is used only to protect singleton instance creation; acquiring and releasing connections must not be done concurrently from multiple threads. """ ``` ########## providers/sftp/tests/conftest.py: ########## @@ -16,4 +16,36 @@ # under the License. 
from __future__ import annotations +from collections.abc import Generator +from typing import TYPE_CHECKING, Any +from unittest.mock import AsyncMock, patch + +import pytest +from asyncssh import SFTPClient, SSHClientConnection + +if TYPE_CHECKING: + from airflow.providers.sftp.hooks.sftp import SFTPHookAsync + pytest_plugins = "tests_common.pytest_plugin" + + [email protected] +def sftp_hook_mocked() -> Generator[tuple[SFTPHookAsync, SFTPClient], Any, None]: + """ + Fixture that mocks SFTPHookAsync._get_conn with SSH + SFTP async mocks. + Returns a tuple (hook, sftp_client_mock) so tests can easily set readdir. + """ + from airflow.providers.sftp.hooks.sftp import SFTPHookAsync + + sftp_client_mock = AsyncMock(spec=SFTPClient) Review Comment: Importing `SFTPHookAsync` inside the fixture function adds an in-function import (against the repo’s general style) and isn't needed: the existing `TYPE_CHECKING`-guarded import only serves annotations and is not available at runtime, so the runtime name belongs in a plain module-level import instead of inside the fixture. Consider moving `from airflow.providers.sftp.hooks.sftp import SFTPHookAsync` to module scope (the existing `TYPE_CHECKING` block can remain). -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
