Refactor the logic used to identify nodes. Add a facility to retrieve the current nodes from the context.
Signed-off-by: Luca Vizzarro <luca.vizza...@arm.com>
Reviewed-by: Paul Szczepanek <paul.szczepa...@arm.com>
---
 dts/framework/testbed_model/node.py     | 38 +++++++++++++++++++++++++
 dts/framework/testbed_model/topology.py |  8 ++----
 2 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/dts/framework/testbed_model/node.py b/dts/framework/testbed_model/node.py
index e6737cd173..474157490d 100644
--- a/dts/framework/testbed_model/node.py
+++ b/dts/framework/testbed_model/node.py
@@ -15,6 +15,7 @@
 
 from functools import cached_property
 from pathlib import PurePath
+from typing import Literal, TypeAlias
 
 from framework.config.node import (
     OS,
@@ -202,3 +203,40 @@ def create_session(node_config: NodeConfiguration, name: str, logger: DTSLogger)
             return LinuxSession(node_config, name, logger)
         case _:
             raise ConfigurationError(f"Unsupported OS {node_config.os}")
+
+
+LocalNodeIdentifier: TypeAlias = Literal["local"]
+"""Local node identifier for testbed model."""
+
+RemoteNodeIdentifier: TypeAlias = Literal["sut", "tg"]
+"""Remote node identifiers for testbed model."""
+
+NodeIdentifier: TypeAlias = Literal["local", "sut", "tg"]
+"""Node identifiers for testbed model."""
+
+
+def get_node(node_identifier: NodeIdentifier) -> Node | None:
+    """Get the node based on the identifier.
+
+    Args:
+        node_identifier: The identifier of the node.
+
+    Returns:
+        The node object corresponding to the identifier, or :data:`None` if the identifier is
+        "local".
+
+    Raises:
+        InternalError: If the node identifier is unknown.
+    """
+    if node_identifier == "local":
+        return None
+
+    from framework.context import get_ctx
+
+    ctx = get_ctx()
+    if node_identifier == "sut":
+        return ctx.sut_node
+    elif node_identifier == "tg":
+        return ctx.tg_node
+    else:
+        raise InternalError(f"Unknown node identifier: {node_identifier}")
diff --git a/dts/framework/testbed_model/topology.py b/dts/framework/testbed_model/topology.py
index 899ea0ad3a..9fc056b330 100644
--- a/dts/framework/testbed_model/topology.py
+++ b/dts/framework/testbed_model/topology.py
@@ -12,12 +12,12 @@
 from collections.abc import Iterator
 from dataclasses import dataclass
 from enum import Enum
-from typing import Literal, NamedTuple
+from typing import NamedTuple
 
 from typing_extensions import Self
 
 from framework.exception import ConfigurationError, InternalError
-from framework.testbed_model.node import Node
+from framework.testbed_model.node import Node, NodeIdentifier
 
 from .port import DriverKind, Port, PortConfig
 
@@ -47,10 +47,6 @@ class PortLink(NamedTuple):
     tg_port: Port
 
 
-NodeIdentifier = Literal["sut", "tg"]
-"""The node identifier."""
-
-
 @dataclass(frozen=True)
 class Topology:
     """Testbed topology.
-- 
2.43.0