Fix systrace.py to work on Windows.

The previous implementation used Python's select.select() call,
which only works on network sockets on Windows.

Fixed CR/LF stripping to work on Windows.

Added a progress indicator that shows while data is being collected.

Change-Id: I1b39dde8e73956eef57eebf489e611f23ee6a7c6
diff --git a/systrace.py b/systrace.py
index 4fe549b..792ac9f 100755
--- a/systrace.py
+++ b/systrace.py
@@ -21,12 +21,17 @@
 # pylint: disable=g-bad-import-order,g-import-not-at-top
 import optparse
 import os
+import Queue
 import re
-import select
 import subprocess
+import threading
 import time
 import zlib
 
+# Text that ADB sends, but does not need to be displayed to the user.
+ADB_IGNORE_REGEXP = r'^capturing trace\.\.\. done|^capturing trace\.\.\.'
+# The number of seconds to wait on output from ADB.
+ADB_STDOUT_READ_TIMEOUT = 0.2
 # The adb shell command to initiate a trace.
 ATRACE_BASE_ARGS = ['atrace']
 # If a custom list of categories is not specified, traces will include
@@ -34,6 +39,10 @@
 DEFAULT_CATEGORIES = 'sched gfx view dalvik webview input disk am wm'.split()
 # The command to list trace categories.
 LIST_CATEGORIES_ARGS = ATRACE_BASE_ARGS + ['--list_categories']
+# Minimum number of seconds between displaying status updates.
+MIN_TIME_BETWEEN_STATUS_UPDATES = 0.2
+# ADB sends this text to indicate the beginning of the trace data.
+TRACE_START_REGEXP = r'TRACE\:'
 # Plain-text trace data should always start with this string.
 TRACE_TEXT_HEADER = '# tracer'
 
@@ -57,6 +66,67 @@
     pass
 
 
+class FileReaderThread(threading.Thread):
+  """Reads data from a file/pipe on a worker thread.
+
+  Use the standard threading.Thread object API to start and interact with the
+  thread (start(), join(), etc.).
+  """
+
+  def __init__(self, file_object, output_queue, text_file, chunk_size=-1):
+    """Initializes a FileReaderThread.
+
+    Args:
+      file_object: The file or pipe to read from.
+      output_queue: A Queue.Queue object that will receive the data.
+      text_file: If True, the file will be read one line at a time, and
+          chunk_size will be ignored.  If False, line breaks are ignored and
+          chunk_size must be set to a positive integer.
+      chunk_size: When processing a non-text file (text_file = False),
+          chunk_size is the amount of data to copy into the queue with each
+          read operation.  For text files, this parameter is ignored.
+    """
+    threading.Thread.__init__(self)
+    self._file_object = file_object
+    self._output_queue = output_queue
+    self._text_file = text_file
+    self._chunk_size = chunk_size
+    assert text_file or chunk_size > 0
+
+  def run(self):
+    """Overrides Thread's run() function.
+
+    Returns when an EOF is encountered.
+    """
+    if self._text_file:
+      # Read a text file one line at a time.
+      for line in self._file_object:
+        self._output_queue.put(line)
+    else:
+      # Read binary or text data until we get to EOF.
+      while True:
+        chunk = self._file_object.read(self._chunk_size)
+        if not chunk:
+          break
+        self._output_queue.put(chunk)
+
+  def set_chunk_size(self, chunk_size):
+    """Change the read chunk size.
+
+    This function can only be called if the FileReaderThread object was
+    created with an initial chunk_size > 0.
+    Args:
+      chunk_size: the new chunk size for this file.  Must be > 0.
+    """
+    # The chunk size can be changed asynchronously while a file is being read
+    # in a worker thread.  However, the type of file cannot be changed after
+    # the FileReaderThread has been created.  These asserts verify that we are
+    # only changing the chunk size, and not the type of file.
+    assert not self._text_file
+    assert chunk_size > 0
+    self._chunk_size = chunk_size
+
+
 def add_adb_serial(adb_command, device_serial):
   if device_serial is not None:
     adb_command.insert(1, device_serial)
@@ -158,6 +228,18 @@
   return []
 
 
+def status_update(last_update_time):
+  current_time = time.time()
+  if (current_time - last_update_time) >= MIN_TIME_BETWEEN_STATUS_UPDATES:
+    # Gathering a trace may take a while.  Keep printing something so users
+    # don't think the script has hung.
+    sys.stdout.write('.')
+    sys.stdout.flush()
+    return current_time
+
+  return last_update_time
+
+
 def parse_options():
   """Parses and checks the command-line options.
 
@@ -267,79 +349,147 @@
   return (tracer_args, expect_trace)
 
 
-def collect_trace_data(tracer_args):
+def collect_trace_data(tracer_args, expect_trace):
   """Invokes and communicates with the trace process.
 
   Args:
     tracer_args: The command-line to execute.
+    expect_trace: True if the given command should return tracing data.
   Returns:
     The captured trace data.
   """
-  adb = subprocess.Popen(tracer_args, stdout=subprocess.PIPE,
-                         stderr=subprocess.PIPE)
-
-  result = None
-  data = []
-
-  # Read the text portion of the output and watch for the 'TRACE:' marker that
-  # indicates the start of the trace data.
-  while result is None:
-    ready = select.select([adb.stdout, adb.stderr], [],
-                          [adb.stdout, adb.stderr])
-    if adb.stderr in ready[0]:
-      err = os.read(adb.stderr.fileno(), 4096)
-      sys.stderr.write(err)
-      sys.stderr.flush()
-    if adb.stdout in ready[0]:
-      out = os.read(adb.stdout.fileno(), 4096)
-      parts = out.split('\nTRACE:', 1)
-
-      txt = parts[0].replace('\r', '')
-      if len(parts) == 2:
-        # The '\nTRACE:' match stole the last newline from the text, so add it
-        # back here.
-        txt += '\n'
-      sys.stdout.write(txt)
-      sys.stdout.flush()
-
-      if len(parts) == 2:
-        data.append(parts[1])
-        sys.stdout.write('downloading trace...')
-        sys.stdout.flush()
-        break
-
-    result = adb.poll()
-
-  # Read and buffer the data portion of the output.
-  while True:
-    ready = select.select([adb.stdout, adb.stderr], [],
-                          [adb.stdout, adb.stderr])
-    keepReading = False
-    if adb.stderr in ready[0]:
-      err = os.read(adb.stderr.fileno(), 4096)
-      if len(err) > 0:
-        keepReading = True
-        sys.stderr.write(err)
-        sys.stderr.flush()
-    if adb.stdout in ready[0]:
-      out = os.read(adb.stdout.fileno(), 4096)
-      if len(out) > 0:
-        keepReading = True
-        data.append(out)
-
-    if result is not None and not keepReading:
-      break
-
-    result = adb.poll()
-
-  if result != 0:
-    print >> sys.stderr, 'adb returned error code %d' % result
+  try:
+    adb = subprocess.Popen(tracer_args, stdout=subprocess.PIPE,
+                           stderr=subprocess.PIPE)
+  except OSError as error:
+    print >> sys.stderr, ('The command "%s" failed with the following error:' %
+                          ' '.join(tracer_args))
+    print >> sys.stderr, '    ', error
     sys.exit(1)
 
-  return data
+  # Read the output from ADB in a worker thread.  This allows us to monitor the
+  # progress of ADB and bail if ADB becomes unresponsive for any reason.
+
+  # Limit the stdout_queue to 128 entries because we will initially be reading
+  # one byte at a time.  When the queue fills up, the reader thread will
+  # block until there is room in the queue.  Once we start downloading the trace
+  # data, we will switch to reading data in larger chunks, and 128 entries
+  # should be plenty for that purpose.
+  stdout_queue = Queue.Queue(maxsize=128)
+  stderr_queue = Queue.Queue()
+
+  if expect_trace:
+    # Use stdout.write() (here and for the rest of this function) instead
+    # of print() to avoid extra newlines.
+    sys.stdout.write('Capturing trace...')
+
+  # Use a chunk_size of 1 for stdout so we can display the output to
+  # the user without waiting for a full line to be sent.
+  stdout_thread = FileReaderThread(adb.stdout, stdout_queue, text_file=False,
+                                   chunk_size=1)
+  stderr_thread = FileReaderThread(adb.stderr, stderr_queue, text_file=True)
+  stdout_thread.start()
+  stderr_thread.start()
+
+  # Holds the trace data returned by ADB.
+  trace_data = []
+  # Keep track of the current line so we can find the TRACE_START_REGEXP.
+  current_line = ''
+  # Set to True once we've received the TRACE_START_REGEXP.
+  reading_trace_data = False
+
+  last_status_update_time = time.time()
+
+  while (stdout_thread.isAlive() or stderr_thread.isAlive() or
+         not stdout_queue.empty() or not stderr_queue.empty()):
+    if expect_trace:
+      last_status_update_time = status_update(last_status_update_time)
+
+    while not stderr_queue.empty():
+      # Pass along errors from adb.
+      line = stderr_queue.get()
+      sys.stderr.write(line)
+
+    # Read stdout from adb.  The loop exits if we don't get any data for
+    # ADB_STDOUT_READ_TIMEOUT seconds.
+    while True:
+      try:
+        chunk = stdout_queue.get(True, ADB_STDOUT_READ_TIMEOUT)
+      except Queue.Empty:
+        # Didn't get any data, so exit the loop to check that ADB is still
+        # alive and print anything sent to stderr.
+        break
+
+      if reading_trace_data:
+        # Save, but don't print, the trace data.
+        trace_data.append(chunk)
+      else:
+        if not expect_trace:
+          sys.stdout.write(chunk)
+        else:
+          # Buffer the output from ADB so we can remove some strings that
+          # don't need to be shown to the user.
+          current_line += chunk
+          if re.match(TRACE_START_REGEXP, current_line):
+            # We are done capturing the trace.
+            sys.stdout.write('Done.\n')
+            # Now we start downloading the trace data.
+            sys.stdout.write('Downloading trace...')
+            current_line = ''
+            # Use a larger chunk size for efficiency since we no longer
+            # need to worry about parsing the stream.
+            stdout_thread.set_chunk_size(4096)
+            reading_trace_data = True
+          elif chunk == '\n' or chunk == '\r':
+            # Remove ADB output that we don't care about.
+            current_line = re.sub(ADB_IGNORE_REGEXP, '', current_line)
+            if len(current_line) > 1:
+              # ADB printed something that we didn't understand, so show
+              # it to the user (might be helpful for debugging).
+              sys.stdout.write(current_line)
+            # Reset our current line.
+            current_line = ''
+
+  if expect_trace:
+    if reading_trace_data:
+      # Indicate to the user that the data download is complete.
+      sys.stdout.write('Done.\n')
+    else:
+      # We didn't receive the trace start tag, so something went wrong.
+      sys.stdout.write('ERROR.\n')
+      # Show any buffered ADB output to the user.
+      current_line = re.sub(ADB_IGNORE_REGEXP, '', current_line)
+      if current_line:
+        sys.stdout.write(current_line)
+        sys.stdout.write('\n')
+
+  # The threads should already have stopped, so this is just for cleanup.
+  stdout_thread.join()
+  stderr_thread.join()
+
+  adb.stdout.close()
+  adb.stderr.close()
+
+  # The adb process should be done since its I/O pipes are closed.  Call
+  # poll() to set the returncode.
+  adb.poll()
+
+  if adb.returncode != 0:
+    print >> sys.stderr, ('The command "%s" returned error code %d.' %
+                          (' '.join(tracer_args), adb.returncode))
+    sys.exit(1)
+
+  return trace_data
 
 
 def extract_thread_list(trace_data):
+  """Removes the thread list from the given trace data.
+
+  Args:
+    trace_data: The raw trace data (before decompression).
+  Returns:
+    A tuple containing the trace data and a map of thread ids to thread names.
+  """
   threads = {}
   parts = re.split('USER +PID +PPID +VSIZE +RSS +WCHAN +PC +NAME',
                    trace_data, 1)
@@ -355,37 +505,47 @@
   return (trace_data, threads)
 
 
-def strip_and_decompress_trace(data, fix_threads):
+def strip_and_decompress_trace(trace_data):
+  """Fixes new-lines and decompresses trace data.
+
+  Args:
+    trace_data: The trace data returned by atrace.
+  Returns:
+    The decompressed trace data.
+  """
   # Collapse CRLFs that are added by adb shell.
-  if data.startswith('\r\n'):
-    data = data.replace('\r\n', '\n')
+  if trace_data.startswith('\r\n'):
+    trace_data = trace_data.replace('\r\n', '\n')
+  elif trace_data.startswith('\r\r\n'):
+    # On Windows, adb adds an extra '\r' character for each line.
+    trace_data = trace_data.replace('\r\r\n', '\n')
 
   # Skip the initial newline.
-  data = data[1:]
+  trace_data = trace_data[1:]
 
-  if not data:
-    print >> sys.stderr, ('No data was captured.  Output file was not '
-                          'written.')
-    sys.exit(1)
-
-  # Indicate to the user that the data download is complete.
-  print ' done\n'
-
-  # Extract the thread list dumped by ps.
-  threads = {}
-  if fix_threads:
-    data, threads = extract_thread_list(data)
-
-  if data.startswith(TRACE_TEXT_HEADER):
-    # Plain-text data.
-    out = data
-  else:
+  if not trace_data.startswith(TRACE_TEXT_HEADER):
     # No header found, so assume the data is compressed.
-    out = zlib.decompress(data)
-  return (out, threads)
+    trace_data = zlib.decompress(trace_data)
+
+  # Enforce Unix line-endings.
+  trace_data = trace_data.replace('\r', '')
+
+  # Skip any initial newlines.
+  while trace_data and trace_data[0] == '\n':
+    trace_data = trace_data[1:]
+
+  return trace_data
 
 
 def fix_thread_names(trace_data, thread_names):
+  """Replaces thread ids with their names.
+
+  Args:
+    trace_data: The atrace data.
+    thread_names: A mapping of thread ids to thread names.
+  Returns:
+    The updated trace data.
+  """
   def repl(m):
     tid = int(m.group(2))
     if tid > 0:
@@ -404,10 +564,22 @@
 
 
 def preprocess_trace_data(options, trace_data):
+  """Performs various processing on atrace data.
+
+  Args:
+    options: The command-line options passed to this script.
+    trace_data: The raw trace data.
+  Returns:
+    The processed trace data.
+  """
   trace_data = ''.join(trace_data)
 
-  trace_data, thread_names = strip_and_decompress_trace(trace_data,
-                                                        options.fix_threads)
+  if options.fix_threads:
+    # Extract the thread list dumped by ps.
+    trace_data, thread_names = extract_thread_list(trace_data)
+
+  if trace_data:
+    trace_data = strip_and_decompress_trace(trace_data)
 
   if not trace_data:
     print >> sys.stderr, ('No data was captured.  Output file was not '
@@ -424,6 +596,13 @@
 
 
 def write_trace_html(html_filename, script_dir, trace_data):
+  """Writes out a trace html file.
+
+  Args:
+    html_filename: The name of the file to write.
+    script_dir: The directory containing this script.
+    trace_data: The atrace data.
+  """
   html_prefix = read_asset(script_dir, 'prefix.html')
   html_suffix = read_asset(script_dir, 'suffix.html')
   trace_viewer_html = read_asset(script_dir, 'systrace_trace_viewer.html')
@@ -455,7 +634,7 @@
   options, categories = parse_options()
   tracer_args, expect_trace = construct_trace_command(options, categories)
 
-  trace_data = collect_trace_data(tracer_args)
+  trace_data = collect_trace_data(tracer_args, expect_trace)
 
   if not expect_trace:
     # Nothing more to do.