Merge "Add controlling for Amarisoft devices."
diff --git a/acts/framework/acts/controllers/amarisoft_lib/amarisoft_client.py b/acts/framework/acts/controllers/amarisoft_lib/amarisoft_client.py
new file mode 100644
index 0000000..e440696
--- /dev/null
+++ b/acts/framework/acts/controllers/amarisoft_lib/amarisoft_client.py
@@ -0,0 +1,222 @@
+#!/usr/bin/env python3

+#

+#   Copyright 2022 - Google

+#

+#   Licensed under the Apache License, Version 2.0 (the "License");

+#   you may not use this file except in compliance with the License.

+#   You may obtain a copy of the License at

+#

+#       http://www.apache.org/licenses/LICENSE-2.0

+#

+#   Unless required by applicable law or agreed to in writing, software

+#   distributed under the License is distributed on an "AS IS" BASIS,

+#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.

+#   See the License for the specific language governing permissions and

+#   limitations under the License.

+

+import asyncio

+import json

+import logging

+from typing import Any, Mapping, Optional, Tuple

+

+from acts.controllers.amarisoft_lib import ssh_utils

+import immutabledict

+import websockets

+

# Location of each configuration type's active config file, as a suffix of the
# directory returned by AmariSoftClient.get_config_dir(cfg_type) (the target of
# the /root/<cfg_type> symlink). The two strings are concatenated, so every
# value must be a relative suffix beginning with '/', e.g.
# '/root/lteenb-linux-2020-12-14' + '/config/enb.cfg'.
# NOTE(review): 'ims' is resolved through /root/ims; on installs where ims.cfg
# only lives under the mme directory that symlink must exist — confirm.
_CONFIG_DIR_MAPPING = immutabledict.immutabledict({
    'enb': '/config/enb.cfg',
    'mme': '/config/mme.cfg',
    'ims': '/config/ims.cfg',
    'mbms': '/config/mbmsgw.cfg',
    'ots': '/config/ots.cfg'
})

+

+

class MessageFailureError(Exception):
  """Signals that the callbox rejected or mis-answered a websocket message."""

+

+

class AmariSoftClient(ssh_utils.RemoteClient):
  """Client that controls an Amarisoft callbox over SSH and websockets.

  The callbox is a base-station simulator whose output signal depends on the
  network configuration it runs, for example T-Mobile NSA LTE band 66 + NR
  band 71. Shell operations (LTE service control, configuration symlinks) use
  the inherited `run_cmd`; remote-API messages go through the callbox's
  websocket interfaces (see `send_message`).
  """

  async def _send_message_to_callbox(self, uri: str,
                                     msg: str) -> Tuple[str, str]:
    """Sends one message over a new websocket connection.

    Args:
      uri: The URI of the websocket interface, e.g. 'ws://<host>:<port>/'.
      msg: The message to send to the callbox.

    Returns:
      A (head, body) tuple holding the first two frames received on the
      connection.
    """
    # NOTE(review): `extra_headers` is the keyword of the legacy websockets
    # client; newer websockets releases call it `additional_headers`.
    # Confirm against the pinned websockets version.
    async with websockets.connect(
        uri, extra_headers={'origin': 'Test'}) as websocket:
      await websocket.send(msg)
      head = await websocket.recv()
      body = await websocket.recv()
    return head, body

  def send_message(self, port: str, msg: str) -> Tuple[str, str]:
    """Sends a message to the callbox and waits for its response.

    Args:
      port: The port of the websocket interface (see
        amarisoft_constants.PortNumber).
      msg: The message to send to the callbox.

    Returns:
      A (head, body) tuple of raw response frames; `verify_response` decodes
      and checks them.
    """
    # Use a private event loop: asyncio.get_event_loop() is deprecated when no
    # loop is running, and asyncio.run() would reset the thread's current
    # loop, which other code in the process may rely on.
    loop = asyncio.new_event_loop()
    try:
      return loop.run_until_complete(
          self._send_message_to_callbox(f'ws://{self.host}:{port}/', msg))
    finally:
      loop.close()

  def verify_response(
      self, func: str, head: str,
      body: str) -> Tuple[Mapping[str, Any], Mapping[str, Any]]:
    """Makes sure there are no error messages in Amarisoft's response.

    If a message produces an error, the response has an "error" string field
    describing it. For example:
      {
        "message": "ready",
        "message_id": <message id>,
        "error": <error message>,
        "type": "ENB",
        "name": <name>,
      }

    Args:
      func: The message name sent to Amarisoft.
      head: Raw JSON of the response head frame.
      body: Raw JSON of the response body frame.

    Returns:
      The decoded (head, body) JSON objects.

    Raises:
      MessageFailureError: The head is not a "ready" message, the body carries
        an "error" field, or the body answers a different message than `func`.
    """
    loaded_head = json.loads(head)
    loaded_body = json.loads(body)

    # Use .get() so malformed responses still surface as MessageFailureError
    # instead of a KeyError.
    if loaded_head.get('message') != 'ready':
      raise MessageFailureError(
          'Fail to get response from callbox, message: '
          f'{loaded_head.get("error")}')
    if 'error' in loaded_body:
      raise MessageFailureError(
          f'Fail to excute {func} with error message: {loaded_body["error"]}')
    if loaded_body.get('message') != func:
      raise MessageFailureError(
          f'The message sent was {loaded_body.get("message")} instead of '
          f'{func}.')
    return loaded_head, loaded_body

  def lte_service_stop(self) -> None:
    """Stops to output signal (systemctl stop lte)."""
    self.run_cmd('systemctl stop lte')

  def lte_service_start(self) -> None:
    """Starts to output signal (systemctl start lte)."""
    self.run_cmd('systemctl start lte')

  def lte_service_restart(self) -> None:
    """Restarts to output signal (systemctl restart lte)."""
    self.run_cmd('systemctl restart lte')

  def lte_service_enable(self) -> None:
    """Enables the lte systemd unit (systemctl enable lte)."""
    self.run_cmd('systemctl enable lte')

  def lte_service_disable(self) -> None:
    """Disables the lte systemd unit (systemctl disable lte)."""
    self.run_cmd('systemctl disable lte')

  def lte_service_is_active(self) -> bool:
    """Checks whether the lte service is active.

    Returns:
      True if `systemctl is-active lte` reports "active" or "reloading" (the
      states systemctl itself treats as active), False otherwise, including
      "inactive", "failed", "activating" or no output.
    """
    states = (line.strip() for line in self.run_cmd('systemctl is-active lte'))
    return any(state in ('active', 'reloading') for state in states)

  def set_config_dir(self, cfg_type: str, path: str) -> None:
    """Points the /root/<cfg_type> symlink at a new directory.

    Args:
      cfg_type: The type of target configuration. (e.g. mme, enb ...etc.)
      path: The path of target configuration. (e.g.
        /root/lteenb-linux-2020-12-14)
    """
    path_old = self.get_config_dir(cfg_type)
    if path != path_old:
      logging.info('set new path %s (was %s)', path, path_old)
      self.run_cmd(f'ln -sfn {path} /root/{cfg_type}')
    else:
      logging.info('path %s does not change.', path_old)

  def get_config_dir(self, cfg_type: str) -> Optional[str]:
    """Gets the target of the /root/<cfg_type> symlink.

    Args:
      cfg_type: Target configuration type. (e.g. mme, enb...etc.)

    Returns:
      The path the symlink points to, or None if readlink printed nothing.
    """
    result = self.run_cmd(f'readlink /root/{cfg_type}')
    if not result:
      logging.warning('%s path not found.', cfg_type)
      return None
    return result[0].strip()

  def set_config_file(self, cfg_type: str, cfg_file: str) -> None:
    """Sets the configuration to be executed.

    Args:
      cfg_type: The type of target configuration. (e.g. mme, enb...etc.)
      cfg_file: The configuration to be executed. (e.g.
        /root/lteenb-linux-2020-12-14/config/gnb.cfg )

    Raises:
      KeyError: `cfg_type` is not a known configuration type.
      FileNotFoundError: The configuration directory for `cfg_type` or the
        file `cfg_file` does not exist.
    """
    # Look up the suffix first so an unknown type fails before any SSH call.
    cfg_suffix = _CONFIG_DIR_MAPPING[cfg_type]
    cfg_dir = self.get_config_dir(cfg_type)
    if not cfg_dir:
      raise FileNotFoundError(
          f'The {cfg_type} configuration directory does not exist')
    if not self.is_file_exist(cfg_file):
      raise FileNotFoundError("The command file doesn't exist")
    self.run_cmd(f'ln -sfn {cfg_file} {cfg_dir}{cfg_suffix}')

  def get_config_file(self, cfg_type: str) -> Optional[str]:
    """Gets the current configuration of specific configuration type.

    Args:
      cfg_type: The type of target configuration. (e.g. mme, enb...etc.)

    Returns:
      The current configuration with absolute path, or None when the
      configuration directory or the config symlink cannot be resolved.

    Raises:
      KeyError: `cfg_type` is not a known configuration type.
    """
    cfg_suffix = _CONFIG_DIR_MAPPING[cfg_type]
    cfg_dir = self.get_config_dir(cfg_type)
    if not cfg_dir:
      return None
    result = self.run_cmd(f'readlink {cfg_dir}{cfg_suffix}')
    return result[0].strip() if result else None

  def get_all_config_dir(self) -> Mapping[str, Optional[str]]:
    """Gets the configuration directories of ots, enb, mme and mbms.

    Returns:
      A dict mapping each configuration type to its directory, or None when
      the directory cannot be resolved.
    """
    config_dir = {}
    for cfg_type in ('ots', 'enb', 'mme', 'mbms'):
      config_dir[cfg_type] = self.get_config_dir(cfg_type)
      logging.debug('get path of %s: %s', cfg_type, config_dir[cfg_type])
    return config_dir
+

diff --git a/acts/framework/acts/controllers/amarisoft_lib/amarisoft_constants.py b/acts/framework/acts/controllers/amarisoft_lib/amarisoft_constants.py
new file mode 100644
index 0000000..c62bf2a
--- /dev/null
+++ b/acts/framework/acts/controllers/amarisoft_lib/amarisoft_constants.py
@@ -0,0 +1,14 @@
+"""Constants for test."""

+

+

class PortNumber:
  """Websocket interface ports of the LTE service, kept as strings."""

  # Core network and radio.
  URI_MME = '9000'
  URI_ENB = '9001'
  URI_UE = '9002'
  # Auxiliary services.
  URI_IMS = '9003'
  URI_MBMS = '9004'
  URI_PROBE = '9005'
  URI_LICENSE = '9006'
  URI_MON = '9007'
  URI_VIEW = '9008'

diff --git a/acts/framework/acts/controllers/amarisoft_lib/ssh_utils.py b/acts/framework/acts/controllers/amarisoft_lib/ssh_utils.py
new file mode 100644
index 0000000..856c612
--- /dev/null
+++ b/acts/framework/acts/controllers/amarisoft_lib/ssh_utils.py
@@ -0,0 +1,181 @@
+"""ssh utils."""

+

+import logging

+from typing import Sequence

+

+import paramiko

+

# Number of attempts RemoteClient.connect() and RemoteClient.ssh_close() make.
COMMAND_RETRY_TIMES: int = 3

+

+

class RunCommandError(Exception):
  """Signals that a remote shell command wrote to stderr."""

+

+

class NotConnectedError(Exception):
  """Signals an SSH/SFTP operation attempted without a live session."""

+

+

class RemoteClient:
  """SSH/SFTP client for running commands and transferring files remotely.

  Attributes:
    host: A string representing the IP address of amarisoft.
    port: A string representing the SSH port (default '22').
    username: A string representing the username of amarisoft.
    password: A string representing the password of amarisoft.
    ssh: The paramiko.SSHClient used for the connection.
    sftp: The paramiko.SFTPClient opened by connect(); None before a
      successful connect() and after ssh_close().
  """

  def __init__(self,
               host: str,
               username: str,
               password: str,
               port: str = '22') -> None:
    self.host = host
    self.port = port
    self.username = username
    self.password = password
    self.ssh = paramiko.SSHClient()
    self.sftp = None

  def ssh_is_connected(self) -> bool:
    """Checks whether the SSH transport is up.

    Returns:
      True if SSH is connected, False otherwise (including before connect()
      succeeds and after ssh_close()).
    """
    if not self.ssh:
      return False
    # get_transport() is None until connect() succeeds and again after
    # close(), so guard before calling is_active().
    transport = self.ssh.get_transport()
    return transport is not None and bool(transport.is_active())

  def ssh_close(self) -> bool:
    """Closes the SFTP session and the SSH connection.

    Returns:
      True if the SSH session is closed, False if it is still active after
      COMMAND_RETRY_TIMES close attempts.
    """
    if self.sftp:
      self.sftp.close()
      self.sftp = None
    for _ in range(COMMAND_RETRY_TIMES):
      if not self.ssh_is_connected():
        return True
      self.ssh.close()
    # Re-check so a close that succeeded on the last attempt reports True.
    return not self.ssh_is_connected()

  def connect(self) -> bool:
    """Creates the SSH connection and an SFTP session on top of it.

    Unknown host keys are accepted automatically (paramiko.AutoAddPolicy).

    Returns:
      True if success, False after COMMAND_RETRY_TIMES failed attempts.
    """
    # NOTE(review): AutoAddPolicy trusts any host key (no MITM protection);
    # acceptable for a lab callbox, revisit for untrusted networks.
    self.ssh.set_missing_host_key_policy(paramiko.AutoAddPolicy())
    for attempt in range(1, COMMAND_RETRY_TIMES + 1):
      try:
        # paramiko documents `port` as an int.
        self.ssh.connect(self.host, int(self.port), self.username,
                         self.password)
        transport = self.ssh.get_transport()
        transport.set_keepalive(1)
        self.sftp = paramiko.SFTPClient.from_transport(transport)
        return True
      except Exception:  # pylint: disable=broad-except
        logging.warning('ssh connect to %s:%s failed (attempt %d/%d).',
                        self.host, self.port, attempt, COMMAND_RETRY_TIMES,
                        exc_info=True)
        self.ssh_close()
    return False

  def run_cmd(self, cmd: str) -> Sequence[str]:
    """Runs shell command.

    Args:
      cmd: Command to be executed.

    Returns:
      Standard output of the shell command, one entry per line.

    Raises:
       RunCommandError: Raised when the command writes anything to stderr.
       NotConnectedError: Raised when run command without SSH connect.
    """
    if not self.ssh_is_connected():
      raise NotConnectedError('ssh remote has not been established')

    logging.debug('ssh remote -> %s', cmd)
    _, stdout, stderr = self.ssh.exec_command(cmd)
    # Drain stdout first: unread channel data holds the SSH flow-control
    # window, so a command with large stdout could stall while we block
    # waiting for stderr EOF.
    out = stdout.readlines()
    err = stderr.readlines()
    if err:
      logging.error('command failed: %s', cmd)
      raise RunCommandError(err)
    return out

  def is_file_exist(self, file: str) -> bool:
    """Checks whether a regular file exists on the remote host.

    Args:
      file: Target file with absolute path.

    Returns:
      True if file exist, False otherwise.
    """
    output = self.run_cmd(f'if [ -f "{file}" ]; then echo "exist"; fi')
    return any('exist' in line for line in output)

  def _check_sftp_ready(self) -> None:
    """Raises NotConnectedError unless both SSH and SFTP sessions are up."""
    if not self.ssh_is_connected():
      raise NotConnectedError('ssh remote has not been established')
    if not self.sftp:
      raise NotConnectedError('sftp remote has not been established')

  def sftp_upload(self, src: str, dst: str) -> bool:
    """Uploads a local file to remote side.

    Args:
      src: The local file with absolute path.
      dst: The remote absolute path, including the file name.
      For example:
        sftp_upload('/tmp/sample_config.yml', '/root/sample_config.yml')

    Returns:
      True if the file exists on the remote side after upload, False otherwise.

    Raises:
       NotConnectedError: Raised when the SSH or SFTP session is not up.
    """
    self._check_sftp_ready()
    logging.info('[local] %s -> [remote] %s', src, dst)
    self.sftp.put(src, dst)
    return self.is_file_exist(dst)

  def sftp_download(self, src: str, dst: str) -> bool:
    """Downloads a remote file to local.

    Args:
      src: The remote file with absolute path.
      dst: The local absolute path to put the file.

    Returns:
      True if the file exists locally after download, False otherwise.

    Raises:
       NotConnectedError: Raised when the SSH or SFTP session is not up.
    """
    import os  # Local import: only this method needs it.

    self._check_sftp_ready()
    logging.info('[remote] %s -> [local] %s', src, dst)
    self.sftp.get(src, dst)
    # dst is a local path, so verify it on the local filesystem;
    # is_file_exist() would look for it on the remote host.
    return os.path.isfile(dst)

  def sftp_list_dir(self, path: str) -> Sequence[str]:
    """Lists the names of the entries in the given remote path.

    Args:
      path: The remote directory to list.

    Returns:
      The sorted names of the entries in the given path.

    Raises:
       NotConnectedError: Raised when the SSH or SFTP session is not up.
    """
    self._check_sftp_ready()
    return sorted(self.sftp.listdir(path))
+