Refactored main().

Split main into smaller functions for improved readability.
This is in preparation for adding support for Windows.

Change-Id: I4786821b0350fd937e4b5d21a4ced7b7e154edc1
diff --git a/systrace.py b/systrace.py
index f99ccaa..0ac116e 100755
--- a/systrace.py
+++ b/systrace.py
@@ -158,13 +158,13 @@
   return []
 
 
-def main():
-  device_sdk_version = get_device_sdk_version()
-  if device_sdk_version < 18:
-    legacy_script = os.path.join(os.path.dirname(sys.argv[0]),
-                                 'systrace-legacy.py')
-    os.execv(legacy_script, sys.argv)
+def parse_options():
+  """Parses and checks the command-line options.
 
+  Returns:
+    A tuple containing the options structure and a list of categories to
+    be traced.
+  """
   usage = 'Usage: %prog [options] [category1 [category2 ...]]'
   desc = 'Example: %prog -b 32768 -t 15 gfx input view sched freq'
   parser = optparse.OptionParser(usage=usage, description=desc)
@@ -210,6 +210,26 @@
   if options.link_assets or options.asset_dir != 'trace-viewer':
     parser.error('--link-assets and --asset-dir are deprecated.')
 
+  if (options.trace_time is not None) and (options.trace_time <= 0):
+    parser.error('the trace time must be a positive number')
+
+  if (options.trace_buf_size is not None) and (options.trace_buf_size <= 0):
+    parser.error('the trace buffer size must be a positive number')
+
+  return (options, categories)
+
+
+def construct_trace_command(options, categories):
+  """Builds a command-line used to invoke a trace process.
+
+  Args:
+    options: The command-line options.
+    categories: The trace categories to capture.
+  Returns:
+    A tuple where the first element is an array of command-line arguments, and
+    the second element is a boolean which will be true if the commend will
+    stream trace data.
+  """
   if options.list_categories:
     tracer_args = construct_adb_shell_command(LIST_CATEGORIES_ARGS,
                                               options.device_serial)
@@ -223,17 +243,11 @@
     if options.compress_trace_data:
       atrace_args.extend(['-z'])
 
-    if options.trace_time is not None:
-      if options.trace_time > 0:
-        atrace_args.extend(['-t', str(options.trace_time)])
-      else:
-        parser.error('the trace time must be a positive number')
+    if (options.trace_time is not None) and (options.trace_time > 0):
+      atrace_args.extend(['-t', str(options.trace_time)])
 
-    if options.trace_buf_size is not None:
-      if options.trace_buf_size > 0:
-        atrace_args.extend(['-b', str(options.trace_buf_size)])
-      else:
-        parser.error('the trace buffer size must be a positive number')
+    if (options.trace_buf_size is not None) and (options.trace_buf_size > 0):
+      atrace_args.extend(['-b', str(options.trace_buf_size)])
 
     if options.app_name is not None:
       atrace_args.extend(['-a', options.app_name])
@@ -250,10 +264,17 @@
     tracer_args = construct_adb_shell_command(atrace_args,
                                               options.device_serial)
 
-  script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
+    return (tracer_args, expect_trace)
 
-  html_filename = options.output_file
 
+def collect_trace_data(tracer_args):
+  """Invokes and communicates with the trace process.
+
+  Args:
+    tracer_args: The command-line to execute.
+  Returns:
+    The captured trace data.
+  """
   adb = subprocess.Popen(tracer_args, stdout=subprocess.PIPE,
                          stderr=subprocess.PIPE)
 
@@ -311,86 +332,140 @@
 
     result = adb.poll()
 
-  if result == 0:
-    if expect_trace:
-      data = ''.join(data)
-
-      # Collapse CRLFs that are added by adb shell.
-      if data.startswith('\r\n'):
-        data = data.replace('\r\n', '\n')
-
-      # Skip the initial newline.
-      data = data[1:]
-
-      if not data:
-        print >> sys.stderr, ('No data was captured.  Output file was not '
-                              'written.')
-        sys.exit(1)
-      else:
-        # Indicate to the user that the data download is complete.
-        print ' done\n'
-
-      # Extract the thread list dumped by ps.
-      threads = {}
-      if options.fix_threads:
-        parts = re.split('USER +PID +PPID +VSIZE +RSS +WCHAN +PC +NAME',
-                         data, 1)
-        if len(parts) == 2:
-          data = parts[0]
-          for line in parts[1].splitlines():
-            cols = line.split(None, 8)
-            if len(cols) == 9:
-              tid = int(cols[1])
-              name = cols[8]
-              threads[tid] = name
-
-      if data.startswith(TRACE_TEXT_HEADER):
-        # Plain-text data.
-        out = data
-      else:
-        # No header found, so assume the data is compressed.
-        out = zlib.decompress(data)
-
-      # Preprocess the data.
-      if options.fix_threads:
-        def repl(m):
-          tid = int(m.group(2))
-          if tid > 0:
-            name = threads.get(tid)
-            if name is None:
-              name = m.group(1)
-              if name == '<...>':
-                name = '<' + str(tid) + '>'
-              threads[tid] = name
-            return name + '-' + m.group(2)
-          else:
-            return m.group(0)
-        out = re.sub(r'^\s*(\S+)-(\d+)', repl, out, flags=re.MULTILINE)
-
-      if options.fix_circular:
-        out = fix_circular_traces(out)
-
-      html_prefix = read_asset(script_dir, 'prefix.html')
-      html_suffix = read_asset(script_dir, 'suffix.html')
-      trace_viewer_html = read_asset(script_dir, 'systrace_trace_viewer.html')
-
-      html_file = open(html_filename, 'w')
-      html_file.write(html_prefix.replace('{{SYSTRACE_TRACE_VIEWER_HTML}}',
-                                          trace_viewer_html))
-
-      html_file.write('<!-- BEGIN TRACE -->\n'
-                      '  <script class="trace-data" type="application/text">\n')
-      html_file.write(out)
-      html_file.write('  </script>\n<!-- END TRACE -->\n')
-
-      html_file.write(html_suffix)
-      html_file.close()
-      print '\n    wrote file://%s\n' % os.path.abspath(options.output_file)
-
-  else:  # i.e. result != 0
+  if result != 0:
     print >> sys.stderr, 'adb returned error code %d' % result
     sys.exit(1)
 
+  return data
+
+
+def extract_thread_list(trace_data):
+  threads = {}
+  parts = re.split('USER +PID +PPID +VSIZE +RSS +WCHAN +PC +NAME',
+                   trace_data, 1)
+  if len(parts) == 2:
+    trace_data = parts[0]
+    for line in parts[1].splitlines():
+      cols = line.split(None, 8)
+      if len(cols) == 9:
+        tid = int(cols[1])
+        name = cols[8]
+        threads[tid] = name
+
+  return (trace_data, threads)
+
+
+def strip_and_decompress_trace(data, fix_threads):
+  # Collapse CRLFs that are added by adb shell.
+  if data.startswith('\r\n'):
+    data = data.replace('\r\n', '\n')
+
+  # Skip the initial newline.
+  data = data[1:]
+
+  if not data:
+    print >> sys.stderr, ('No data was captured.  Output file was not '
+                          'written.')
+    sys.exit(1)
+
+  # Indicate to the user that the data download is complete.
+  print ' done\n'
+
+  # Extract the thread list dumped by ps.
+  threads = {}
+  if fix_threads:
+    data, threads = extract_thread_list(data)
+
+  if data.startswith(TRACE_TEXT_HEADER):
+    # Plain-text data.
+    out = data
+  else:
+    # No header found, so assume the data is compressed.
+    out = zlib.decompress(data)
+  return (out, threads)
+
+
+def fix_thread_names(trace_data, thread_names):
+  def repl(m):
+    tid = int(m.group(2))
+    if tid > 0:
+      name = thread_names.get(tid)
+      if name is None:
+        name = m.group(1)
+        if name == '<...>':
+          name = '<' + str(tid) + '>'
+        thread_names[tid] = name
+      return name + '-' + m.group(2)
+    else:
+      return m.group(0)
+  trace_data = re.sub(r'^\s*(\S+)-(\d+)', repl, trace_data,
+                      flags=re.MULTILINE)
+  return trace_data
+
+
+def preprocess_trace_data(options, trace_data):
+  trace_data = ''.join(trace_data)
+
+  trace_data, thread_names = strip_and_decompress_trace(trace_data,
+                                                        options.fix_threads)
+
+  if not trace_data:
+    print >> sys.stderr, ('No data was captured.  Output file was not '
+                          'written.')
+    sys.exit(1)
+
+  if options.fix_threads:
+    trace_data = fix_thread_names(trace_data, thread_names)
+
+  if options.fix_circular:
+    trace_data = fix_circular_traces(trace_data)
+
+  return trace_data
+
+
+def write_trace_html(html_filename, script_dir, trace_data):
+  html_prefix = read_asset(script_dir, 'prefix.html')
+  html_suffix = read_asset(script_dir, 'suffix.html')
+  trace_viewer_html = read_asset(script_dir, 'systrace_trace_viewer.html')
+
+  # Open the file in binary mode to prevent python from changing the
+  # line endings.
+  html_file = open(html_filename, 'wb')
+  html_file.write(html_prefix.replace('{{SYSTRACE_TRACE_VIEWER_HTML}}',
+                                      trace_viewer_html))
+
+  html_file.write('<!-- BEGIN TRACE -->\n'
+                  '  <script class="trace-data" type="application/text">\n')
+  html_file.write(trace_data)
+  html_file.write('  </script>\n<!-- END TRACE -->\n')
+
+  html_file.write(html_suffix)
+  html_file.close()
+  print '\n    wrote file://%s\n' % os.path.abspath(html_filename)
+
+
+def main():
+  device_sdk_version = get_device_sdk_version()
+  if device_sdk_version < 18:
+    legacy_script = os.path.join(os.path.dirname(sys.argv[0]),
+                                 'systrace-legacy.py')
+    # execv() does not return.
+    os.execv(legacy_script, sys.argv)
+
+  options, categories = parse_options()
+  tracer_args, expect_trace = construct_trace_command(options, categories)
+
+  trace_data = collect_trace_data(tracer_args)
+
+  if not expect_trace:
+    # Nothing more to do.
+    return
+
+  trace_data = preprocess_trace_data(options, trace_data)
+
+  script_dir = os.path.dirname(os.path.abspath(sys.argv[0]))
+  write_trace_html(options.output_file, script_dir, trace_data)
+
 
 def read_asset(src_dir, filename):
   return open(os.path.join(src_dir, filename)).read()