Add support for omaha in update_device.py

Automatically detect if the device is running an update_engine with
USE_OMAHA or not, and use different command based on that.

Made the http server serving payload a fake omaha server as well.

Added support for overriding the payload public key on the device.

Bug: 27316596
Test: update a device that uses omaha
Change-Id: I230d58ff105189dc8a624d46ee79cbe90c5f09b0
diff --git a/scripts/update_device.py b/scripts/update_device.py
index 63d0fc9..a5be0a5 100755
--- a/scripts/update_device.py
+++ b/scripts/update_device.py
@@ -27,10 +27,17 @@
 import threading
 import zipfile
 
+import update_payload.payload
+
 
 # The path used to store the OTA package when applying the package from a file.
 OTA_PACKAGE_PATH = '/data/ota_package'
 
+# The path to the payload public key on the device.
+PAYLOAD_KEY_PATH = '/etc/update_engine/update-payload-key.pub.pem'
+
+# The port on the device that update_engine should connect to.
+DEVICE_PORT = 1234
 
 def CopyFileObjLength(fsrc, fdst, buffer_size=128 * 1024, copy_length=None):
   """Copy from a file object to another.
@@ -161,6 +168,57 @@
     CopyFileObjLength(f, self.wfile, copy_length=end_range - start_range)
 
 
+  def do_POST(self):  # pylint: disable=invalid-name
+    """Reply with the omaha response xml."""
+    if self.path != '/update':
+      self.send_error(404, 'Unknown request')
+      return
+
+    if not self.serving_payload:
+      self.send_error(500, 'No serving payload set')
+      return
+
+    try:
+      f = open(self.serving_payload, 'rb')
+    except IOError:
+      self.send_error(404, 'File not found')
+      return
+
+    self.send_response(200)
+    self.send_header("Content-type", "text/xml")
+    self.end_headers()
+
+    stat = os.fstat(f.fileno())
+    sha256sum = subprocess.check_output(['sha256sum', self.serving_payload])
+    payload_hash = sha256sum.split()[0]
+    payload = update_payload.Payload(f)
+    payload.Init()
+
+    xml = '''
+        <?xml version="1.0" encoding="UTF-8"?>
+        <response protocol="3.0">
+          <app appid="appid">
+            <updatecheck status="ok">
+              <urls>
+                <url codebase="http://127.0.0.1:%d/"/>
+              </urls>
+              <manifest version="0.0.0.1">
+                <actions>
+                  <action event="install" run="payload"/>
+                  <action event="postinstall" MetadataSize="%d"/>
+                </actions>
+                <packages>
+                  <package hash_sha256="%s" name="payload" size="%d"/>
+                </packages>
+              </manifest>
+            </updatecheck>
+          </app>
+        </response>
+    ''' % (DEVICE_PORT, payload.metadata_size, payload_hash, stat.st_size)
+    self.wfile.write(xml.strip())
+    return
+
+
 class ServerThread(threading.Thread):
   """A thread for serving HTTP requests."""
 
@@ -203,6 +261,12 @@
           '--size=%d' % ota.size, '--headers="%s"' % headers]
 
 
+def OmahaUpdateCommand(omaha_url):
+  """Return the command to run to start the update in a device using Omaha."""
+  return ['update_engine_client', '--update', '--follow',
+          '--omaha_url=%s' % omaha_url]
+
+
 class AdbHost(object):
   """Represents a device connected via ADB."""
 
@@ -235,6 +299,22 @@
     p.wait()
     return p.returncode
 
+  def adb_output(self, command):
+    """Run an ADB command like "adb push" and return the output.
+
+    Args:
+      command: list of strings containing command and arguments to run
+
+    Returns:
+      the program's output as a string.
+
+    Raises:
+      subprocess.CalledProcessError on command exit != 0.
+    """
+    command = self._command_prefix + command
+    logging.info('Running: %s', ' '.join(str(x) for x in command))
+    return subprocess.check_output(command, universal_newlines=True)
+
 
 def main():
   parser = argparse.ArgumentParser(description='Android A/B OTA helper.')
@@ -248,6 +328,8 @@
                       help='The specific device to use.')
   parser.add_argument('--no-verbose', action='store_true',
                       help='Less verbose output')
+  parser.add_argument('--public-key', type=str, default='',
+                      help='Override the public key used to verify payload.')
   args = parser.parse_args()
   logging.basicConfig(
       level=logging.WARNING if args.no_verbose else logging.INFO)
@@ -262,6 +344,9 @@
   # List of commands to perform the update.
   cmds = []
 
+  help_cmd = ['shell', 'su', '0', 'update_engine_client', '--help']
+  use_omaha = 'omaha' in dut.adb_output(help_cmd)
+
   if args.file:
     # Update via pushing a file to /data.
     device_ota_file = os.path.join(OTA_PACKAGE_PATH, 'debug.zip')
@@ -273,16 +358,32 @@
   else:
     # Update via sending the payload over the network with an "adb reverse"
     # command.
-    device_port = 1234
-    payload_url = 'http://127.0.0.1:%d/payload' % device_port
+    payload_url = 'http://127.0.0.1:%d/payload' % DEVICE_PORT
     server_thread = StartServer(args.otafile)
     cmds.append(
-        ['reverse', 'tcp:%d' % device_port, 'tcp:%d' % server_thread.port])
-    finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % device_port])
+        ['reverse', 'tcp:%d' % DEVICE_PORT, 'tcp:%d' % server_thread.port])
+    finalize_cmds.append(['reverse', '--remove', 'tcp:%d' % DEVICE_PORT])
+
+  if args.public_key:
+    payload_key_dir = os.path.dirname(PAYLOAD_KEY_PATH)
+    cmds.append(
+        ['shell', 'su', '0', 'mount', '-t', 'tmpfs', 'tmpfs', payload_key_dir])
+    # Allow adb push to payload_key_dir
+    cmds.append(['shell', 'su', '0', 'chcon', 'u:object_r:shell_data_file:s0',
+                 payload_key_dir])
+    cmds.append(['push', args.public_key, PAYLOAD_KEY_PATH])
+    # Allow update_engine to read it.
+    cmds.append(['shell', 'su', '0', 'chcon', '-R', 'u:object_r:system_file:s0',
+                 payload_key_dir])
+    finalize_cmds.append(['shell', 'su', '0', 'umount', payload_key_dir])
 
   try:
     # The main update command using the configured payload_url.
-    update_cmd = AndroidUpdateCommand(args.otafile, payload_url)
+    if use_omaha:
+      update_cmd = \
+          OmahaUpdateCommand('http://127.0.0.1:%d/update' % DEVICE_PORT)
+    else:
+      update_cmd = AndroidUpdateCommand(args.otafile, payload_url)
     cmds.append(['shell', 'su', '0'] + update_cmd)
 
     for cmd in cmds: