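Refactor the node identification logic: move the NodeIdentifier type alias
from topology.py to node.py, split it into local ("local") and remote
("sut", "tg") identifiers, and add a get_node() facility to retrieve the
current SUT or TG node from the context.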

Signed-off-by: Luca Vizzarro <luca.vizza...@arm.com>
Reviewed-by: Paul Szczepanek <paul.szczepa...@arm.com>
---
 dts/framework/testbed_model/node.py     | 38 +++++++++++++++++++++++++
 dts/framework/testbed_model/topology.py |  8 ++----
 2 files changed, 40 insertions(+), 6 deletions(-)

diff --git a/dts/framework/testbed_model/node.py b/dts/framework/testbed_model/node.py
index e6737cd173..474157490d 100644
--- a/dts/framework/testbed_model/node.py
+++ b/dts/framework/testbed_model/node.py
@@ -15,6 +15,7 @@
 
 from functools import cached_property
 from pathlib import PurePath
+from typing import Literal, TypeAlias
 
 from framework.config.node import (
     OS,
@@ -202,3 +203,40 @@ def create_session(node_config: NodeConfiguration, name: str, logger: DTSLogger)
             return LinuxSession(node_config, name, logger)
         case _:
             raise ConfigurationError(f"Unsupported OS {node_config.os}")
+
+
+LocalNodeIdentifier: TypeAlias = Literal["local"]
+"""Identifier of the local node; :func:`get_node` returns :data:`None` for it."""
+
+RemoteNodeIdentifier: TypeAlias = Literal["sut", "tg"]
+"""Identifiers of the remote nodes: the system under test and the traffic generator."""
+
+NodeIdentifier: TypeAlias = Literal[LocalNodeIdentifier, RemoteNodeIdentifier]
+"""Any node identifier: the union of the local and remote node identifiers."""
+
+
+def get_node(node_identifier: NodeIdentifier) -> Node | None:
+    """Get the current node matching the identifier from the context.
+
+    Args:
+        node_identifier: The identifier of the node.
+
+    Returns:
+        The SUT or TG node from the current context, or :data:`None` if the identifier is
+            "local".
+
+    Raises:
+        InternalError: If the node identifier is unknown.
+    """
+    if node_identifier == "local":
+        return None
+
+    from framework.context import get_ctx
+
+    ctx = get_ctx()
+    if node_identifier == "sut":
+        return ctx.sut_node
+    elif node_identifier == "tg":
+        return ctx.tg_node
+    else:
+        raise InternalError(f"Unknown node identifier: {node_identifier}")
diff --git a/dts/framework/testbed_model/topology.py b/dts/framework/testbed_model/topology.py
index 899ea0ad3a..9fc056b330 100644
--- a/dts/framework/testbed_model/topology.py
+++ b/dts/framework/testbed_model/topology.py
@@ -12,12 +12,12 @@
 from collections.abc import Iterator
 from dataclasses import dataclass
 from enum import Enum
-from typing import Literal, NamedTuple
+from typing import NamedTuple
 
 from typing_extensions import Self
 
 from framework.exception import ConfigurationError, InternalError
-from framework.testbed_model.node import Node
+from framework.testbed_model.node import Node, NodeIdentifier
 
 from .port import DriverKind, Port, PortConfig
 
@@ -47,10 +47,6 @@ class PortLink(NamedTuple):
     tg_port: Port
 
 
-NodeIdentifier = Literal["sut", "tg"]
-"""The node identifier."""
-
-
 @dataclass(frozen=True)
 class Topology:
     """Testbed topology.
-- 
2.43.0

Reply via email to