Add support for multiple files per trace in ADB Proxy

Required because transaction traces now output multiple files

Test: Make sure we can still get the traces from winscope using the proxy
Change-Id: I9c74aab0c659f09f7c30bda6fba06a646d36c637
diff --git a/tools/winscope/adb_proxy/winscope_proxy.py b/tools/winscope/adb_proxy/winscope_proxy.py
index c2c03a3..46510ad 100755
--- a/tools/winscope/adb_proxy/winscope_proxy.py
+++ b/tools/winscope/adb_proxy/winscope_proxy.py
@@ -38,6 +38,7 @@
 from http import HTTPStatus
 from http.server import HTTPServer, BaseHTTPRequestHandler
 from tempfile import NamedTemporaryFile
+import base64
 
 # CONFIG #
 
@@ -46,7 +47,7 @@
 PORT = 5544
 
 # Keep in sync with WINSCOPE_PROXY_VERSION in Winscope DataAdb.vue
-VERSION = '0.5'
+VERSION = '0.6'
 
 WINSCOPE_VERSION_HEADER = "Winscope-Proxy-Version"
 WINSCOPE_TOKEN_HEADER = "Winscope-Token"
@@ -62,44 +63,79 @@
 log = logging.getLogger("ADBProxy")
 
 
+class File:
+    def __init__(self, file, filetype) -> None:
+        self.file = file
+        self.type = filetype
+
+    def get_filepaths(self, device_id):
+        return [self.file]
+
+    def get_filetype(self):
+        return self.type
+
+
+class FileMatcher:
+    def __init__(self, path, matcher, filetype) -> None:
+        self.path = path
+        self.matcher = matcher
+        self.type = filetype
+
+    def get_filepaths(self, device_id):
+        matchingFiles = call_adb(
+            f"shell su root find {self.path} -name {self.matcher}", device_id)
+
+        return matchingFiles.split('\n')[:-1]
+
+    def get_filetype(self):
+        return self.type
+
+
 class TraceTarget:
     """Defines a single parameter to trace.
 
     Attributes:
-        file: the path on the device the trace results are saved to.
+        file_matchers: the matchers used to identify the paths on the device the trace results are saved to.
         trace_start: command to start the trace from adb shell, must not block.
         trace_stop: command to stop the trace, should block until the trace is stopped.
     """
 
-    def __init__(self, file: str, trace_start: str, trace_stop: str) -> None:
-        self.file = file
+    def __init__(self, files, trace_start: str, trace_stop: str) -> None:
+        if type(files) is not list:
+            files = [files]
+        self.files = files
         self.trace_start = trace_start
         self.trace_stop = trace_stop
 
 
+# Order of files matters as they will be expected in that order and decoded in that order
 TRACE_TARGETS = {
     "window_trace": TraceTarget(
-        "/data/misc/wmtrace/wm_trace.pb",
+        File("/data/misc/wmtrace/wm_trace.pb", "window_trace"),
         'su root cmd window tracing start\necho "WM trace started."',
         'su root cmd window tracing stop >/dev/null 2>&1'
     ),
     "layers_trace": TraceTarget(
-        "/data/misc/wmtrace/layers_trace.pb",
+        File("/data/misc/wmtrace/layers_trace.pb", "layers_trace"),
         'su root service call SurfaceFlinger 1025 i32 1\necho "SF trace started."',
         'su root service call SurfaceFlinger 1025 i32 0 >/dev/null 2>&1'
     ),
     "screen_recording": TraceTarget(
-        "/data/local/tmp/screen.winscope.mp4",
+        File("/data/local/tmp/screen.winscope.mp4", "screen_recording"),
         'screenrecord --bit-rate 8M /data/local/tmp/screen.winscope.mp4 >/dev/null 2>&1 &\necho "ScreenRecorder started."',
         'pkill -l SIGINT screenrecord >/dev/null 2>&1'
     ),
     "transaction": TraceTarget(
-        "/data/misc/wmtrace/transaction_trace.pb",
+        [
+            File("/data/misc/wmtrace/transaction_trace.pb", "transactions"),
+            FileMatcher("/data/misc/wmtrace/", "transaction_merges_*.pb",
+                        "transaction_merges"),
+        ],
         'su root service call SurfaceFlinger 1020 i32 1\necho "SF transactions recording started."',
         'su root service call SurfaceFlinger 1020 i32 0 >/dev/null 2>&1'
     ),
     "proto_log": TraceTarget(
-        "/data/misc/wmtrace/wm_log.pb",
+        File("/data/misc/wmtrace/wm_log.pb", "proto_log"),
         'su root cmd window logging start\necho "WM logging started."',
         'su root cmd window logging stop >/dev/null 2>&1'
     ),
@@ -109,9 +145,10 @@
 class SurfaceFlingerTraceConfig:
     """Handles optional configuration for surfaceflinger traces.
     """
+
     def __init__(self) -> None:
         # default config flags
-        self.flags =  1 << 0 |  1 << 1
+        self.flags = 1 << 0 | 1 << 1
 
     def add(self, config: str) -> None:
         self.flags |= CONFIG_FLAG[config]
@@ -122,12 +159,14 @@
     def command(self) -> str:
         return f'su root service call SurfaceFlinger 1033 i32 {self.flags}'
 
+
 CONFIG_FLAG = {
     "composition": 1 << 2,
     "metadata": 1 << 3,
     "hwc": 1 << 4
 }
 
+
 class DumpTarget:
     """Defines a single parameter to trace.
 
@@ -136,18 +175,20 @@
         dump_command: command to dump state to file.
     """
 
-    def __init__(self, file: str, dump_command: str) -> None:
-        self.file = file
+    def __init__(self, files, dump_command: str) -> None:
+        if type(files) is not list:
+            files = [files]
+        self.files = files
         self.dump_command = dump_command
 
 
 DUMP_TARGETS = {
     "window_dump": DumpTarget(
-        "/data/local/tmp/wm_dump.pb",
+        File("/data/local/tmp/wm_dump.pb", "window_dump"),
         'su root dumpsys window --proto > /data/local/tmp/wm_dump.pb'
     ),
     "layers_dump": DumpTarget(
-        "/data/local/tmp/sf_dump.pb",
+        File("/data/local/tmp/sf_dump.pb", "layers_dump"),
         'su root dumpsys SurfaceFlinger --proto > /data/local/tmp/sf_dump.pb'
     )
 }
@@ -161,18 +202,21 @@
     try:
         with open(WINSCOPE_TOKEN_LOCATION, 'r') as token_file:
             token = token_file.readline()
-            log.debug("Loaded token {} from {}".format(token, WINSCOPE_TOKEN_LOCATION))
+            log.debug("Loaded token {} from {}".format(
+                token, WINSCOPE_TOKEN_LOCATION))
             return token
     except IOError:
         token = secrets.token_hex(32)
         os.makedirs(os.path.dirname(WINSCOPE_TOKEN_LOCATION), exist_ok=True)
         try:
             with open(WINSCOPE_TOKEN_LOCATION, 'w') as token_file:
-                log.debug("Created and saved token {} to {}".format(token, WINSCOPE_TOKEN_LOCATION))
+                log.debug("Created and saved token {} to {}".format(
+                    token, WINSCOPE_TOKEN_LOCATION))
                 token_file.write(token)
             os.chmod(WINSCOPE_TOKEN_LOCATION, 0o600)
         except IOError:
-            log.error("Unable to save persistent token {} to {}".format(token, WINSCOPE_TOKEN_LOCATION))
+            log.error("Unable to save persistent token {} to {}".format(
+                token, WINSCOPE_TOKEN_LOCATION))
         return token
 
 
@@ -189,8 +233,10 @@
     server.send_header('Cache-Control', 'no-cache, no-store, must-revalidate')
     server.send_header('Access-Control-Allow-Origin', '*')
     server.send_header('Access-Control-Allow-Methods', 'POST, GET, OPTIONS')
-    server.send_header('Access-Control-Allow-Headers', WINSCOPE_TOKEN_HEADER + ', Content-Type, Content-Length')
-    server.send_header('Access-Control-Expose-Headers', 'Winscope-Proxy-Version')
+    server.send_header('Access-Control-Allow-Headers',
+                       WINSCOPE_TOKEN_HEADER + ', Content-Type, Content-Length')
+    server.send_header('Access-Control-Expose-Headers',
+                       'Winscope-Proxy-Version')
     server.send_header(WINSCOPE_VERSION_HEADER, VERSION)
     server.end_headers()
 
@@ -230,7 +276,8 @@
 
     def __internal_error(self, error: str):
         log.error("Internal error: " + error)
-        self.request.respond(HTTPStatus.INTERNAL_SERVER_ERROR, error.encode("utf-8"), 'text/txt')
+        self.request.respond(HTTPStatus.INTERNAL_SERVER_ERROR,
+                             error.encode("utf-8"), 'text/txt')
 
     def __bad_token(self):
         log.info("Bad token")
@@ -263,11 +310,15 @@
         log.debug("Call: " + ' '.join(command))
         return subprocess.check_output(command, stderr=subprocess.STDOUT, input=stdin).decode('utf-8')
     except OSError as ex:
-        log.debug('Error executing adb command: {}\n{}'.format(' '.join(command), repr(ex)))
-        raise AdbError('Error executing adb command: {}\n{}'.format(' '.join(command), repr(ex)))
+        log.debug('Error executing adb command: {}\n{}'.format(
+            ' '.join(command), repr(ex)))
+        raise AdbError('Error executing adb command: {}\n{}'.format(
+            ' '.join(command), repr(ex)))
     except subprocess.CalledProcessError as ex:
-        log.debug('Error executing adb command: {}\n{}'.format(' '.join(command), ex.output.decode("utf-8")))
-        raise AdbError('Error executing adb command: adb {}\n{}'.format(params, ex.output.decode("utf-8")))
+        log.debug('Error executing adb command: {}\n{}'.format(
+            ' '.join(command), ex.output.decode("utf-8")))
+        raise AdbError('Error executing adb command: adb {}\n{}'.format(
+            params, ex.output.decode("utf-8")))
 
 
 def call_adb_outfile(params: str, outfile, device: str = None, stdin: bytes = None):
@@ -282,8 +333,10 @@
             raise AdbError('Error executing adb command: adb {}\n'.format(params) + err.decode(
                 'utf-8') + '\n' + outfile.read().decode('utf-8'))
     except OSError as ex:
-        log.debug('Error executing adb command: adb {}\n{}'.format(params, repr(ex)))
-        raise AdbError('Error executing adb command: adb {}\n{}'.format(params, repr(ex)))
+        log.debug('Error executing adb command: adb {}\n{}'.format(
+            params, repr(ex)))
+        raise AdbError(
+            'Error executing adb command: adb {}\n{}'.format(params, repr(ex)))
 
 
 class ListDevicesEndpoint(RequestEndpoint):
@@ -321,33 +374,41 @@
         return json.loads(server.rfile.read(length).decode("utf-8"))
 
 
-class FetchFileEndpoint(DeviceRequestEndpoint):
+class FetchFilesEndpoint(DeviceRequestEndpoint):
     def process_with_device(self, server, path, device_id):
         if len(path) != 1:
             raise BadRequest("File not specified")
         if path[0] in TRACE_TARGETS:
-            file_path = TRACE_TARGETS[path[0]].file
+            files = TRACE_TARGETS[path[0]].files
         elif path[0] in DUMP_TARGETS:
-            file_path = DUMP_TARGETS[path[0]].file
+            files = DUMP_TARGETS[path[0]].files
         else:
             raise BadRequest("Unknown file specified")
 
-        with NamedTemporaryFile() as tmp:
-            log.debug("Fetching file {} from device to {}".format(file_path, tmp.name))
-            call_adb_outfile('exec-out su root cat ' + file_path, tmp, device_id)
-            log.debug("Deleting file {} from device".format(file_path))
-            call_adb('shell su root rm ' + file_path, device_id)
-            server.send_response(HTTPStatus.OK)
-            server.send_header('X-Content-Type-Options', 'nosniff')
-            server.send_header('Content-type', 'application/octet-stream')
-            add_standard_headers(server)
-            log.debug("Uploading file {}".format(tmp.name))
-            while True:
-                buf = tmp.read(1024)
-                if buf:
-                    server.wfile.write(buf)
-                else:
-                    break
+        file_buffers = dict()
+
+        for f in files:
+            file_type = f.get_filetype()
+            file_paths = f.get_filepaths(device_id)
+
+            for file_path in file_paths:
+                with NamedTemporaryFile() as tmp:
+                    log.debug(
+                        f"Fetching file {file_path} from device to {tmp.name}")
+                    call_adb_outfile('exec-out su root cat ' +
+                                     file_path, tmp, device_id)
+                    log.debug(f"Deleting file {file_path} from device")
+                    call_adb('shell su root rm ' + file_path, device_id)
+                    log.debug(f"Uploading file {tmp.name}")
+                    if file_type not in file_buffers:
+                        file_buffers[file_type] = []
+                    buf = base64.encodestring(tmp.read()).decode("utf-8")
+                    file_buffers[file_type].append(buf)
+
+        # server.send_header('X-Content-Type-Options', 'nosniff')
+        # add_standard_headers(server)
+        j = json.dumps(file_buffers)
+        server.respond(HTTPStatus.OK, j.encode("utf-8"), "text/json")
 
 
 def check_root(device_id):
@@ -372,34 +433,41 @@
             self.process = subprocess.Popen(shell, stdout=subprocess.PIPE,
                                             stderr=subprocess.PIPE, stdin=subprocess.PIPE, start_new_session=True)
         except OSError as ex:
-            raise AdbError('Error executing adb command: adb shell\n{}'.format(repr(ex)))
+            raise AdbError(
+                'Error executing adb command: adb shell\n{}'.format(repr(ex)))
 
         super().__init__()
 
     def timeout(self):
         if self.is_alive():
-            log.warning("Keep-alive timeout for trace on {}".format(self._device_id))
+            log.warning(
+                "Keep-alive timeout for trace on {}".format(self._device_id))
             self.end_trace()
             if self._device_id in TRACE_THREADS:
                 TRACE_THREADS.pop(self._device_id)
 
     def reset_timer(self):
-        log.debug("Resetting keep-alive clock for trace on {}".format(self._device_id))
+        log.debug(
+            "Resetting keep-alive clock for trace on {}".format(self._device_id))
         if self._keep_alive_timer:
             self._keep_alive_timer.cancel()
-        self._keep_alive_timer = threading.Timer(KEEP_ALIVE_INTERVAL_S, self.timeout)
+        self._keep_alive_timer = threading.Timer(
+            KEEP_ALIVE_INTERVAL_S, self.timeout)
         self._keep_alive_timer.start()
 
     def end_trace(self):
         if self._keep_alive_timer:
             self._keep_alive_timer.cancel()
-        log.debug("Sending SIGINT to the trace process on {}".format(self._device_id))
+        log.debug("Sending SIGINT to the trace process on {}".format(
+            self._device_id))
         self.process.send_signal(signal.SIGINT)
         try:
-            log.debug("Waiting for trace shell to exit for {}".format(self._device_id))
+            log.debug("Waiting for trace shell to exit for {}".format(
+                self._device_id))
             self.process.wait(timeout=5)
         except TimeoutError:
-            log.debug("TIMEOUT - sending SIGKILL to the trace process on {}".format(self._device_id))
+            log.debug(
+                "TIMEOUT - sending SIGKILL to the trace process on {}".format(self._device_id))
             self.process.kill()
         self.join()
 
@@ -411,8 +479,10 @@
         time.sleep(0.2)
         for i in range(10):
             if call_adb("shell su root cat /data/local/tmp/winscope_status", device=self._device_id) == 'TRACE_OK\n':
-                call_adb("shell su root rm /data/local/tmp/winscope_status", device=self._device_id)
-                log.debug("Trace finished successfully on {}".format(self._device_id))
+                call_adb(
+                    "shell su root rm /data/local/tmp/winscope_status", device=self._device_id)
+                log.debug("Trace finished successfully on {}".format(
+                    self._device_id))
                 self._success = True
                 break
             log.debug("Still waiting for cleanup on {}".format(self._device_id))
@@ -464,8 +534,10 @@
         command = StartTrace.TRACE_COMMAND.format(
             '\n'.join([t.trace_stop for t in requested_traces]),
             '\n'.join([t.trace_start for t in requested_traces]))
-        log.debug("Trace requested for {} with targets {}".format(device_id, ','.join(requested_types)))
-        TRACE_THREADS[device_id] = TraceThread(device_id, command.encode('utf-8'))
+        log.debug("Trace requested for {} with targets {}".format(
+            device_id, ','.join(requested_types)))
+        TRACE_THREADS[device_id] = TraceThread(
+            device_id, command.encode('utf-8'))
         TRACE_THREADS[device_id].start()
         server.respond(HTTPStatus.OK, b'', "text/plain")
 
@@ -478,7 +550,8 @@
             TRACE_THREADS[device_id].end_trace()
 
         success = TRACE_THREADS[device_id].success()
-        out = TRACE_THREADS[device_id].out + b"\n" + TRACE_THREADS[device_id].err
+        out = TRACE_THREADS[device_id].out + \
+            b"\n" + TRACE_THREADS[device_id].err
         command = TRACE_THREADS[device_id].trace_command
         TRACE_THREADS.pop(device_id)
         if success:
@@ -489,13 +562,16 @@
                     "utf-8") + "\n### Command: adb -s {} shell ###\n### Input ###\n".format(device_id) + command.decode(
                     "utf-8"))
 
+
 class ConfigTrace(DeviceRequestEndpoint):
     def process_with_device(self, server, path, device_id):
         try:
             requested_configs = self.get_request(server)
             config = SurfaceFlingerTraceConfig()
             for requested_config in requested_configs:
-                if not config.is_valid(requested_config): raise BadRequest(f"Unsupported config {requested_config}\n")
+                if not config.is_valid(requested_config):
+                    raise BadRequest(
+                        f"Unsupported config {requested_config}\n")
                 config.add(requested_config)
         except KeyError as err:
             raise BadRequest("Unsupported trace target\n" + str(err))
@@ -512,7 +588,8 @@
         log.debug(f"Changing trace config on device {device_id}")
         out, err = process.communicate(command.encode('utf-8'))
         if process.returncode != 0:
-            raise AdbError(f"Error executing command:\n {command}\n\n### OUTPUT ###{out.decode('utf-8')}\n{err.decode('utf-8')}")
+            raise AdbError(
+                f"Error executing command:\n {command}\n\n### OUTPUT ###{out.decode('utf-8')}\n{err.decode('utf-8')}")
         log.debug(f"Changing trace config finished on device {device_id}")
         server.respond(HTTPStatus.OK, b'', "text/plain")
 
@@ -522,7 +599,8 @@
         if device_id not in TRACE_THREADS:
             raise BadRequest("No trace in progress for {}".format(device_id))
         TRACE_THREADS[device_id].reset_timer()
-        server.respond(HTTPStatus.OK, str(TRACE_THREADS[device_id].is_alive()).encode("utf-8"), "text/plain")
+        server.respond(HTTPStatus.OK, str(
+            TRACE_THREADS[device_id].is_alive()).encode("utf-8"), "text/plain")
 
 
 class DumpEndpoint(DeviceRequestEndpoint):
@@ -537,7 +615,7 @@
         if not check_root(device_id):
             raise AdbError(
                 "Unable to acquire root privileges on the device - check the output of 'adb -s {} shell su root id'"
-                    .format(device_id))
+                .format(device_id))
         command = '\n'.join(t.dump_command for t in requested_traces)
         shell = ['adb', '-s', device_id, 'shell']
         log.debug("Starting dump shell {}".format(' '.join(shell)))
@@ -555,13 +633,17 @@
 class ADBWinscopeProxy(BaseHTTPRequestHandler):
     def __init__(self, request, client_address, server):
         self.router = RequestRouter(self)
-        self.router.register_endpoint(RequestType.GET, "devices", ListDevicesEndpoint())
-        self.router.register_endpoint(RequestType.GET, "status", StatusEndpoint())
-        self.router.register_endpoint(RequestType.GET, "fetch", FetchFileEndpoint())
+        self.router.register_endpoint(
+            RequestType.GET, "devices", ListDevicesEndpoint())
+        self.router.register_endpoint(
+            RequestType.GET, "status", StatusEndpoint())
+        self.router.register_endpoint(
+            RequestType.GET, "fetch", FetchFilesEndpoint())
         self.router.register_endpoint(RequestType.POST, "start", StartTrace())
         self.router.register_endpoint(RequestType.POST, "end", EndTrace())
         self.router.register_endpoint(RequestType.POST, "dump", DumpEndpoint())
-        self.router.register_endpoint(RequestType.POST, "configtrace", ConfigTrace())
+        self.router.register_endpoint(
+            RequestType.POST, "configtrace", ConfigTrace())
         super().__init__(request, client_address, server)
 
     def respond(self, code: int, data: bytes, mime: str) -> None: