Merge "New function of Acloud restart AVD."
diff --git a/restart/restart.py b/restart/restart.py
index 43253d5..ec2e6e7 100644
--- a/restart/restart.py
+++ b/restart/restart.py
@@ -16,30 +16,84 @@
 This command will restart the CF AVD from a remote instance.
 """
 
-from __future__ import print_function
+import logging
+import subprocess
 
 from acloud import errors
+from acloud.internal import constants
+from acloud.internal.lib import utils
+from acloud.internal.lib.ssh import Ssh
+from acloud.internal.lib.ssh import IP
 from acloud.list import list as list_instances
 from acloud.public import config
 from acloud.public import report
+from acloud.reconnect import reconnect
 
 
-def RestartFromInstance(instance, instance_id):
+logger = logging.getLogger(__name__)
+
+
+def RestartFromInstance(cfg, instance, instance_id):
     """Restart AVD from remote CF instance.
 
     Args:
+        cfg: AcloudConfig object.
         instance: list.Instance() object.
         instance_id: Integer of the instance id.
 
     Returns:
-        A Report instance.
+        A Report instance. If restart_cvd fails on the remote instance, the
+        report carries the error message and its status is set to FAIL.
     """
-    # TODO(162382338): rewrite this function to restart AVD from the remote instance.
-    print("We will restart AVD id (%s) from the instance: %s."
-          % (instance_id, instance.name))
-    return report.Report(command="restart")
+    ssh = Ssh(ip=IP(ip=instance.ip),
+              user=constants.GCE_USER,
+              ssh_private_key_path=cfg.ssh_private_key_path,
+              extra_args_ssh_tunnel=cfg.extra_args_ssh_tunnel)
+    logger.info("Start to restart AVD id (%s) from the instance: %s.",
+                instance_id, instance.name)
+    restart_report = report.Report(command="restart")
+    error_msg = RestartDevice(ssh, instance_id)
+    if error_msg is not None:
+        # restart_cvd failed: record the error in the report rather than
+        # returning a report that hides the failure, and skip reconnecting
+        # to an AVD that did not restart.
+        restart_report.AddError(error_msg)
+        restart_report.SetStatus(report.Status.FAIL)
+        return restart_report
+    reconnect.ReconnectInstance(cfg.ssh_private_key_path,
+                                instance,
+                                report.Report(command="reconnect"),
+                                cfg.extra_args_ssh_tunnel)
+    restart_report.SetStatus(report.Status.SUCCESS)
+    return restart_report
 
 
+@utils.TimeExecute(function_description="Waiting for AVD to restart")
+def RestartDevice(ssh, instance_id):
+    """Restart AVD with the instance id.
+
+    Failures are printed and returned instead of raised so that the caller
+    can record them in its report.
+
+    Args:
+        ssh: Ssh object.
+        instance_id: Integer of the instance id.
+
+    Returns:
+        None if restart_cvd succeeded, otherwise a string describing the
+        failure.
+    """
+    ssh_command = "./bin/restart_cvd --instance_num=%d" % instance_id
+    try:
+        ssh.Run(ssh_command)
+    except (subprocess.CalledProcessError, errors.DeviceConnectionError) as e:
+        error_msg = "Failed to restart AVD id (%s): %s" % (instance_id, e)
+        logger.debug(error_msg)
+        utils.PrintColorString(error_msg, utils.TextColors.FAIL)
+        return error_msg
+    return None
+
+
 def Run(args):
     """Run restart.
 
@@ -50,13 +84,13 @@
 
     Returns:
         A Report instance.
-
-    Raises:
-        errors.CommandArgError: Lack the instance_name in args.
     """
     cfg = config.GetAcloudConfig(args)
     if args.instance_name:
-        instance = list_instances.GetInstancesFromInstanceNames(
-            cfg, [args.instance_name])
-        return RestartFromInstance(instance[0], args.instance_id)
-    raise errors.CommandArgError("Please assign the '--instance-name' in your command.")
+        matched_instances = list_instances.GetInstancesFromInstanceNames(
+            cfg, [args.instance_name])
+        target_instance = matched_instances[0]
+    else:
+        # No instance name given: choose one of the remote instances.
+        target_instance = list_instances.ChooseOneRemoteInstance(cfg)
+    return RestartFromInstance(cfg, target_instance, args.instance_id)
diff --git a/restart/restart_test.py b/restart/restart_test.py
new file mode 100644
index 0000000..659bdbe
--- /dev/null
+++ b/restart/restart_test.py
@@ -0,0 +1,70 @@
+# Copyright 2021 - The Android Open Source Project
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Tests for restart."""
+import unittest
+
+from unittest import mock
+
+from acloud.internal.lib import driver_test_lib
+from acloud.list import list as list_instances
+from acloud.public import config
+from acloud.restart import restart
+
+
+class RestartTest(driver_test_lib.BaseDriverTest):
+    """Test restart."""
+
+    @mock.patch.object(restart, "RestartFromInstance")
+    def testRun(self, mock_restart):
+        """Test Run."""
+        cfg = mock.MagicMock()
+        args = mock.MagicMock()
+        instance_obj = mock.MagicMock()
+        # Test case with provided instance name.
+        args.instance_name = "instance_1"
+        args.instance_id = 1
+        self.Patch(config, "GetAcloudConfig", return_value=cfg)
+        self.Patch(list_instances, "GetInstancesFromInstanceNames",
+                   return_value=[instance_obj])
+        restart.Run(args)
+        mock_restart.assert_called_once_with(cfg, instance_obj,
+                                             args.instance_id)
+
+        # Test case where the user selects one instance to restart the AVD.
+        selected_instance = mock.MagicMock()
+        self.Patch(list_instances, "ChooseOneRemoteInstance",
+                   return_value=selected_instance)
+        args.instance_name = None
+        restart.Run(args)
+        mock_restart.assert_called_with(cfg, selected_instance,
+                                        args.instance_id)
+
+    @mock.patch.object(restart.reconnect, "ReconnectInstance")
+    @mock.patch.object(restart, "RestartDevice", return_value=None)
+    @mock.patch.object(restart, "Ssh")
+    def testRestartFromInstance(self, mock_ssh, mock_restart_device,
+                                mock_reconnect):
+        """Test RestartFromInstance."""
+        cfg = mock.MagicMock()
+        instance = mock.MagicMock()
+        instance.ip = "1.1.1.1"
+        instance.name = "instance_1"
+        restart_report = restart.RestartFromInstance(cfg, instance, 1)
+        mock_restart_device.assert_called_once_with(mock_ssh.return_value, 1)
+        mock_reconnect.assert_called_once()
+        self.assertEqual(restart_report.command, "restart")
+
+
+if __name__ == '__main__':
+    unittest.main()