Merge the SSHSession class into RemoteSession, as there is currently no use for a transport channel other than SSH.
Signed-off-by: Luca Vizzarro <luca.vizza...@arm.com> Reviewed-by: Paul Szczepanek <paul.szczepa...@arm.com> --- doc/api/dts/framework.remote_session.rst | 1 - .../framework.remote_session.ssh_session.rst | 8 -- dts/framework/remote_session/__init__.py | 3 +- .../remote_session/remote_session.py | 110 ++++++++++++----- dts/framework/remote_session/ssh_session.py | 116 ------------------ 5 files changed, 79 insertions(+), 159 deletions(-) delete mode 100644 doc/api/dts/framework.remote_session.ssh_session.rst delete mode 100644 dts/framework/remote_session/ssh_session.py diff --git a/doc/api/dts/framework.remote_session.rst b/doc/api/dts/framework.remote_session.rst index 27c9153e64..e4ebd21ab9 100644 --- a/doc/api/dts/framework.remote_session.rst +++ b/doc/api/dts/framework.remote_session.rst @@ -12,7 +12,6 @@ remote\_session - Node Connections Package :maxdepth: 1 framework.remote_session.remote_session - framework.remote_session.ssh_session framework.remote_session.interactive_remote_session framework.remote_session.interactive_shell framework.remote_session.shell_pool diff --git a/doc/api/dts/framework.remote_session.ssh_session.rst b/doc/api/dts/framework.remote_session.ssh_session.rst deleted file mode 100644 index 4bb51d7db2..0000000000 --- a/doc/api/dts/framework.remote_session.ssh_session.rst +++ /dev/null @@ -1,8 +0,0 @@ -.. SPDX-License-Identifier: BSD-3-Clause - -ssh\_session - SSH Remote Session -================================= - -.. 
automodule:: framework.remote_session.ssh_session - :members: - :show-inheritance: diff --git a/dts/framework/remote_session/__init__.py b/dts/framework/remote_session/__init__.py index 1a5cf6abd3..2c0922acd2 100644 --- a/dts/framework/remote_session/__init__.py +++ b/dts/framework/remote_session/__init__.py @@ -17,7 +17,6 @@ from .interactive_remote_session import InteractiveRemoteSession from .remote_session import RemoteSession -from .ssh_session import SSHSession def create_remote_session( @@ -36,7 +35,7 @@ def create_remote_session( Returns: The SSH remote session. """ - return SSHSession(node_config, name, logger) + return RemoteSession(node_config, name, logger) def create_interactive_session( diff --git a/dts/framework/remote_session/remote_session.py b/dts/framework/remote_session/remote_session.py index 89d4618c41..cdb2dc1ed5 100644 --- a/dts/framework/remote_session/remote_session.py +++ b/dts/framework/remote_session/remote_session.py @@ -4,18 +4,33 @@ # Copyright(c) 2022-2023 University of New Hampshire # Copyright(c) 2024 Arm Limited -"""Base remote session. +"""SSH remote session.""" -This module contains the abstract base class for remote sessions and defines -the structure of the result of a command execution. 
-""" - -from abc import ABC, abstractmethod +import socket +import traceback from dataclasses import InitVar, dataclass, field from pathlib import Path, PurePath +from fabric import Connection # type: ignore[import-untyped] +from invoke.exceptions import ( + CommandTimedOut, + ThreadException, + UnexpectedExit, +) +from paramiko.ssh_exception import ( + AuthenticationException, + BadHostKeyException, + NoValidConnectionsError, + SSHException, +) + from framework.config.node import NodeConfiguration -from framework.exception import RemoteCommandExecutionError +from framework.exception import ( + RemoteCommandExecutionError, + SSHConnectionError, + SSHSessionDeadError, + SSHTimeoutError, +) from framework.logger import DTSLogger from framework.settings import SETTINGS @@ -63,14 +78,11 @@ def __str__(self) -> str: ) -class RemoteSession(ABC): +class RemoteSession: """Non-interactive remote session. - The abstract methods must be implemented in order to connect to a remote host (node) - and maintain a remote session. - The subclasses must use (or implement) some underlying transport protocol (e.g. SSH) - to implement the methods. On top of that, it provides some basic services common to all - subclasses, such as keeping history and logging what's being executed on the remote node. + The connection is implemented with + `the Fabric Python library <https://docs.fabfile.org/en/latest/>`_. Attributes: name: The name of the session. @@ -82,6 +94,7 @@ class RemoteSession(ABC): password: The password used in the connection. Most frequently empty, as the use of passwords is discouraged. history: The executed commands during this session. + session: The underlying Fabric SSH session. 
""" name: str @@ -91,6 +104,7 @@ class RemoteSession(ABC): username: str password: str history: list[CommandResult] + session: Connection _logger: DTSLogger _node_config: NodeConfiguration @@ -128,7 +142,6 @@ def __init__( self._connect() self._logger.info(f"Connection to {self.username}@{self.hostname} successful.") - @abstractmethod def _connect(self) -> None: """Create a connection to the node. @@ -137,7 +150,42 @@ def _connect(self) -> None: The implementation must except all exceptions and convert them to an SSHConnectionError. The implementation may optionally implement retry attempts. + + Raises: + SSHConnectionError: If the connection to the node was not successful. """ + errors = [] + retry_attempts = 10 + login_timeout = 20 if self.port else 10 + for retry_attempt in range(retry_attempts): + try: + self.session = Connection( + self.ip, + user=self.username, + port=self.port, + connect_kwargs={"password": self.password}, + connect_timeout=login_timeout, + ) + self.session.open() + + except (ValueError, BadHostKeyException, AuthenticationException) as e: + self._logger.exception(e) + raise SSHConnectionError(self.hostname) from e + + except (NoValidConnectionsError, socket.error, SSHException) as e: + self._logger.debug(traceback.format_exc()) + self._logger.warning(e) + + error = repr(e) + if error not in errors: + errors.append(error) + + self._logger.info(f"Retrying connection: retry number {retry_attempt + 1}.") + + else: + break + else: + raise SSHConnectionError(self.hostname, errors) def send_command( self, @@ -166,7 +214,18 @@ def send_command( The output of the command along with the return code. 
""" self._logger.info(f"Sending: '{command}'" + (f" with env vars: '{env}'" if env else "")) - result = self._send_command(command, timeout, env) + + try: + output = self.session.run(command, env=env, warn=True, hide=True, timeout=timeout) + except (UnexpectedExit, ThreadException) as e: + self._logger.exception(e) + raise SSHSessionDeadError(self.hostname) from e + except CommandTimedOut as e: + self._logger.exception(e) + raise SSHTimeoutError(command) from e + + result = CommandResult(self.name, command, output.stdout, output.stderr, output.return_code) + if verify and result.return_code: self._logger.debug( f"Command '{command}' failed with return code '{result.return_code}'" @@ -178,24 +237,10 @@ def send_command( self.history.append(result) return result - @abstractmethod - def _send_command(self, command: str, timeout: float, env: dict | None) -> CommandResult: - """Send a command to the connected node. - - The implementation must execute the command remotely with `env` environment variables - and return the result. - - The implementation must except all exceptions and raise: - - * SSHSessionDeadError if the session is not alive, - * SSHTimeoutError if the command execution times out. - """ - - @abstractmethod def is_alive(self) -> bool: """Check whether the remote session is still responding.""" + return self.session.is_connected - @abstractmethod def copy_from(self, source_file: str | PurePath, destination_dir: str | Path) -> None: """Copy a file from the remote Node to the local filesystem. @@ -207,8 +252,8 @@ def copy_from(self, source_file: str | PurePath, destination_dir: str | Path) -> destination_dir: The directory path on the local filesystem where the `source_file` will be saved. """ + self.session.get(str(source_file), str(destination_dir)) - @abstractmethod def copy_to(self, source_file: str | Path, destination_dir: str | PurePath) -> None: """Copy a file from local filesystem to the remote Node. 
@@ -220,7 +265,8 @@ def copy_to(self, source_file: str | Path, destination_dir: str | PurePath) -> N destination_dir: The directory path on the remote Node where the `source_file` will be saved. """ + self.session.put(str(source_file), str(destination_dir)) - @abstractmethod def close(self) -> None: """Close the remote session and free all used resources.""" + self.session.close() diff --git a/dts/framework/remote_session/ssh_session.py b/dts/framework/remote_session/ssh_session.py deleted file mode 100644 index e6e4704bc2..0000000000 --- a/dts/framework/remote_session/ssh_session.py +++ /dev/null @@ -1,116 +0,0 @@ -# SPDX-License-Identifier: BSD-3-Clause -# Copyright(c) 2023 PANTHEON.tech s.r.o. - -"""SSH remote session.""" - -import socket -import traceback -from pathlib import Path, PurePath - -from fabric import Connection # type: ignore[import-untyped] -from invoke.exceptions import ( - CommandTimedOut, - ThreadException, - UnexpectedExit, -) -from paramiko.ssh_exception import ( - AuthenticationException, - BadHostKeyException, - NoValidConnectionsError, - SSHException, -) - -from framework.exception import SSHConnectionError, SSHSessionDeadError, SSHTimeoutError - -from .remote_session import CommandResult, RemoteSession - - -class SSHSession(RemoteSession): - """A persistent SSH connection to a remote Node. - - The connection is implemented with - `the Fabric Python library <https://docs.fabfile.org/en/latest/>`_. - - Attributes: - session: The underlying Fabric SSH connection. - - Raises: - SSHConnectionError: The connection cannot be established. 
- """ - - session: Connection - - def _connect(self) -> None: - errors = [] - retry_attempts = 10 - login_timeout = 20 if self.port else 10 - for retry_attempt in range(retry_attempts): - try: - self.session = Connection( - self.ip, - user=self.username, - port=self.port, - connect_kwargs={"password": self.password}, - connect_timeout=login_timeout, - ) - self.session.open() - - except (ValueError, BadHostKeyException, AuthenticationException) as e: - self._logger.exception(e) - raise SSHConnectionError(self.hostname) from e - - except (NoValidConnectionsError, socket.error, SSHException) as e: - self._logger.debug(traceback.format_exc()) - self._logger.warning(e) - - error = repr(e) - if error not in errors: - errors.append(error) - - self._logger.info(f"Retrying connection: retry number {retry_attempt + 1}.") - - else: - break - else: - raise SSHConnectionError(self.hostname, errors) - - def _send_command(self, command: str, timeout: float, env: dict | None) -> CommandResult: - """Send a command and return the result of the execution. - - Args: - command: The command to execute. - timeout: Wait at most this long in seconds for the command execution to complete. - env: Extra environment variables that will be used in command execution. - - Raises: - SSHSessionDeadError: The session died while executing the command. - SSHTimeoutError: The command execution timed out. 
- """ - try: - output = self.session.run(command, env=env, warn=True, hide=True, timeout=timeout) - - except (UnexpectedExit, ThreadException) as e: - self._logger.exception(e) - raise SSHSessionDeadError(self.hostname) from e - - except CommandTimedOut as e: - self._logger.exception(e) - raise SSHTimeoutError(command) from e - - return CommandResult(self.name, command, output.stdout, output.stderr, output.return_code) - - def is_alive(self) -> bool: - """Overrides :meth:`~.remote_session.RemoteSession.is_alive`.""" - return self.session.is_connected - - def copy_from(self, source_file: str | PurePath, destination_dir: str | Path) -> None: - """Overrides :meth:`~.remote_session.RemoteSession.copy_from`.""" - self.session.get(str(source_file), str(destination_dir)) - - def copy_to(self, source_file: str | Path, destination_dir: str | PurePath) -> None: - """Overrides :meth:`~.remote_session.RemoteSession.copy_to`.""" - self.session.put(str(source_file), str(destination_dir)) - - def close(self) -> None: - """Overrides :meth:`~.remote_session.RemoteSession.close`.""" - self.session.close() -- 2.43.0