llvm_tools: Restructure most patch_manager modes

This is an attempt to gradually replace some of the
behavior of patch_manager.py with a more extendable
structure.

Instead of having the same code handle every patch_manager.py
mode, instead we can have a dispatch for supported modes,
and the fallback to legacy when we can't separate the
existing behavior.

This does not change the external API of patch_manager.py
at all. All unittests still pass, and we're still
applying patches correctly as expected.

BUG=b:188465085, b:227216280
TEST=./patch_manager_unittest.py
TEST=cp patch_manager.py patch_utils.py ${llvm_files}/patch_manager/ \
     && sudo emerge llvm

Change-Id: I43d26d4e903140ce2e490624aaac15d0bae898cd
Reviewed-on: https://chromium-review.googlesource.com/c/chromiumos/third_party/toolchain-utils/+/3661358
Tested-by: Jordan Abrahams-Whitehead <ajordanr@google.com>
Reviewed-by: George Burgess <gbiv@chromium.org>
Commit-Queue: Jordan Abrahams-Whitehead <ajordanr@google.com>
diff --git a/llvm_tools/patch_manager.py b/llvm_tools/patch_manager.py
index c755d88..51a7476 100755
--- a/llvm_tools/patch_manager.py
+++ b/llvm_tools/patch_manager.py
@@ -9,22 +9,25 @@
 import dataclasses
 import json
 import os
+from pathlib import Path
 import subprocess
 import sys
 from typing import Any, Dict, IO, List, Optional, Tuple
 
 from failure_modes import FailureModes
 import get_llvm_hash
+import patch_utils
 from subprocess_helpers import check_call
 from subprocess_helpers import check_output
 
-@dataclasses.dataclass
+
+@dataclasses.dataclass(frozen=True)
 class PatchInfo:
   """Holds info for a round of patch applications."""
   # str types are legacy. Patch lists should
   # probably be PatchEntries,
-  applied_patches: List[str]
-  failed_patches: List[str]
+  applied_patches: List[patch_utils.PatchEntry]
+  failed_patches: List[patch_utils.PatchEntry]
   # Can be deleted once legacy code is removed.
   non_applicable_patches: List[str]
   # Can be deleted once legacy code is removed.
@@ -37,6 +40,7 @@
   def _asdict(self):
     return dataclasses.asdict(self)
 
+
 def is_directory(dir_path):
   """Validates that the argument passed into 'argparse' is a directory."""
 
@@ -298,7 +302,14 @@
     raise ValueError('File does not end in ".json": %s' % patch_metadata_file)
 
   with open(patch_metadata_file, 'w') as patch_file:
-    json.dump(patches, patch_file, indent=4, separators=(',', ': '))
+    _WriteJsonChanges(patches, patch_file)
+
+
+def _WriteJsonChanges(patches: List[Dict[str, Any]], file_io: IO[str]):
+  """Write JSON changes to file, does not acquire new file lock."""
+  json.dump(patches, file_io, indent=4, separators=(',', ': '))
+  # Need to add a newline as json.dump omits it.
+  file_io.write('\n')
 
 
 def GetCommitHashesForBisection(src_path, good_svn_version, bad_svn_version):
@@ -382,6 +393,172 @@
   check_output(get_changes_cmd)
 
 
+def ApplyAllFromJson(svn_version: int,
+                     llvm_src_dir: Path,
+                     patches_json_fp: Path,
+                     continue_on_failure: bool = False) -> PatchInfo:
+  """Attempt to apply some patches to a given LLVM source tree.
+
+  This relies on a PATCHES.json file to be the primary way
+  the patches are applied.
+
+  Args:
+    svn_version: LLVM Subversion revision to patch.
+    llvm_src_dir: llvm-project root-level source directory to patch.
+    patches_json_fp: Filepath to the PATCHES.json file.
+    continue_on_failure: Skip any patches which failed to apply,
+      rather than throw an Exception.
+  """
+  with patches_json_fp.open(encoding='utf-8') as f:
+    patches = patch_utils.json_to_patch_entries(patches_json_fp.parent, f)
+  skipped_patches = []
+  failed_patches = []
+  applied_patches = []
+  for pe in patches:
+    applied, failed_hunks = ApplySinglePatchEntry(svn_version, llvm_src_dir,
+                                                  pe)
+    if applied:
+      applied_patches.append(pe)
+      continue
+    if failed_hunks is not None:
+      if continue_on_failure:
+        failed_patches.append(pe)
+        continue
+      else:
+        _PrintFailedPatch(pe, failed_hunks)
+        raise RuntimeError('failed to apply patch '
+                           f'{pe.patch_path()}: {pe.title()}')
+    # Didn't apply, didn't fail, it was skipped.
+    skipped_patches.append(pe)
+  return PatchInfo(
+      non_applicable_patches=skipped_patches,
+      applied_patches=applied_patches,
+      failed_patches=failed_patches,
+      disabled_patches=[],
+      removed_patches=[],
+      modified_metadata=None,
+  )
+
+
+def ApplySinglePatchEntry(
+    svn_version: int, llvm_src_dir: Path, pe: patch_utils.PatchEntry
+) -> Tuple[bool, Optional[Dict[str, List[patch_utils.Hunk]]]]:
+  """Try to apply a single PatchEntry object.
+
+  Returns:
+    Tuple where the first element indicates whether the patch applied,
+    and the second element is a faild hunk mapping from file name to lists of
+    hunks (if the patch didn't apply).
+  """
+  # Don't apply patches outside of the version range.
+  if not pe.can_patch_version(svn_version):
+    return False, None
+  # Test first to avoid making changes.
+  test_application = pe.test_apply(llvm_src_dir)
+  if not test_application:
+    return False, test_application.failed_hunks
+  # Now actually make changes.
+  application_result = pe.apply(llvm_src_dir)
+  if not application_result:
+    # This should be very rare/impossible.
+    return False, application_result.failed_hunks
+  return True, None
+
+
+def RemoveOldPatches(svn_version: int, llvm_src_dir: Path,
+                     patches_json_fp: Path):
+  """Remove patches that don't and will never apply for the future.
+
+  Patches are determined to be "old" via the "is_old" method for
+  each patch entry.
+
+  Args:
+    svn_version: LLVM SVN version.
+    llvm_src_dir: LLVM source directory.
+    patches_json_fp: Location to edit patches on.
+  """
+  with patches_json_fp.open(encoding='utf-8') as f:
+    patches_list = json.load(f)
+  patch_entries = (patch_utils.PatchEntry.from_dict(llvm_src_dir, elem)
+                   for elem in patches_list)
+  filtered_entries = [
+      entry.to_dict() for entry in patch_entries
+      if not entry.is_old(svn_version)
+  ]
+  with patch_utils.atomic_write(patches_json_fp, encoding='utf-8') as f:
+    _WriteJsonChanges(filtered_entries, f)
+
+
+def UpdateVersionRanges(svn_version: int, llvm_src_dir: Path,
+                        patches_json_fp: Path):
+  """Reduce the version ranges of failing patches.
+
+  Patches which fail to apply will have their 'version_range.until'
+  field reduced to the passed in svn_version.
+
+  Modifies the contents of patches_json_fp.
+
+  Ars:
+    svn_version: LLVM revision number.
+    llvm_src_dir: llvm-project directory path.
+    patches_json_fp: Filepath to the PATCHES.json file.
+  """
+  if IsGitDirty(llvm_src_dir):
+    raise RuntimeError('Cannot test patch applications, llvm_src_dir is dirty')
+  with patches_json_fp.open(encoding='utf-8') as f:
+    patch_entries = patch_utils.json_to_patch_entries(patches_json_fp.parent,
+                                                      f)
+  modified_entries: List[patch_utils.PatchEntry] = []
+  for pe in patch_entries:
+    test_result = pe.test_apply(llvm_src_dir)
+    if not test_result:
+      pe.version_range['until'] = svn_version
+      modified_entries.append(pe)
+    else:
+      # We have to actually apply the patch so that future patches
+      # will stack properly.
+      if not pe.apply(llvm_src_dir).succeeded:
+        CleanSrcTree(llvm_src_dir)
+        raise RuntimeError('Could not apply patch that dry ran successfully')
+  with patch_utils.atomic_write(patches_json_fp, encoding='utf-8') as f:
+    _WriteJsonChanges([p.to_dict() for p in patch_entries], f)
+  for entry in modified_entries:
+    print(f'Stopped applying {entry.rel_patch_path} ({entry.title()}) '
+          f'for r{svn_version}')
+  CleanSrcTree(llvm_src_dir)
+
+
+def IsGitDirty(git_root_dir: Path) -> bool:
+  """Return whether the given git directory has uncommitted changes."""
+  if not git_root_dir.is_dir():
+    raise ValueError(f'git_root_dir {git_root_dir} is not a directory')
+  cmd = ['git', 'ls-files', '-m', '--other', '--exclude-standard']
+  return (subprocess.run(cmd,
+                         stdout=subprocess.PIPE,
+                         check=True,
+                         cwd=git_root_dir,
+                         encoding='utf-8').stdout != "")
+
+
+def _PrintFailedPatch(pe: patch_utils.PatchEntry,
+                      failed_hunks: Dict[str, List[patch_utils.Hunk]]):
+  """Print information about a single failing PatchEntry.
+
+  Args:
+    pe: A PatchEntry that failed.
+    failed_hunks: Hunks for pe which failed as dict:
+      filepath: [Hunk...]
+  """
+  print(f'Could not apply {pe.rel_patch_path}: {pe.title()}', file=sys.stderr)
+  for fp, hunks in failed_hunks.items():
+    print(f'{fp}:', file=sys.stderr)
+    for h in hunks:
+      print(
+          f'- {pe.rel_patch_path} '
+          f'l:{h.patch_hunk_lineno_begin}...{h.patch_hunk_lineno_end}',
+          file=sys.stderr)
+
+
 def HandlePatches(svn_version,
                   patch_metadata_file,
                   filesdir_path,
@@ -701,24 +878,27 @@
   return patch_info
 
 
-def PrintPatchResults(patch_info):
+def PrintPatchResults(patch_info: PatchInfo):
   """Prints the results of handling the patches of a package.
 
   Args:
     patch_info: A dataclass that has information on the patches.
   """
 
+  def _fmt(patches):
+    return (str(pe.patch_path()) for pe in patches)
+
   if patch_info.applied_patches:
     print('\nThe following patches applied successfully:')
-    print('\n'.join(patch_info.applied_patches))
+    print('\n'.join(_fmt(patch_info.applied_patches)))
 
   if patch_info.failed_patches:
     print('\nThe following patches failed to apply:')
-    print('\n'.join(patch_info.failed_patches))
+    print('\n'.join(_fmt(patch_info.failed_patches)))
 
   if patch_info.non_applicable_patches:
     print('\nThe following patches were not applicable:')
-    print('\n'.join(patch_info.non_applicable_patches))
+    print('\n'.join(_fmt(patch_info.non_applicable_patches)))
 
   if patch_info.modified_metadata:
     print('\nThe patch metadata file %s has been modified' %
@@ -726,7 +906,7 @@
 
   if patch_info.disabled_patches:
     print('\nThe following patches were disabled:')
-    print('\n'.join(patch_info.disabled_patches))
+    print('\n'.join(_fmt(patch_info.disabled_patches)))
 
   if patch_info.removed_patches:
     print('\nThe following patches were removed from the patch metadata file:')
@@ -754,16 +934,43 @@
     # SVN version is not used in determining whether a patch is applicable.
     args_output.svn_version = GetHEADSVNVersion(args_output.src_path)
 
-  # Get the results of handling the patches of the package.
-  patch_info = HandlePatches(args_output.svn_version,
-                             args_output.patch_metadata_file,
-                             args_output.filesdir_path, args_output.src_path,
-                             FailureModes(args_output.failure_mode),
-                             args_output.good_svn_version,
-                             args_output.num_patches_to_iterate,
-                             args_output.continue_bisection)
+  def _apply_all(args):
+    result = ApplyAllFromJson(
+        svn_version=args.svn_version,
+        llvm_src_dir=Path(args.src_path),
+        patches_json_fp=Path(args.patch_metadata_file),
+        continue_on_failure=args.failure_mode == FailureModes.CONTINUE)
+    PrintPatchResults(result)
 
-  PrintPatchResults(patch_info)
+  def _remove(args):
+    RemoveOldPatches(args.svn_version, Path(args.src_path),
+                     Path(args.patch_metadata_file))
+
+  def _disable(args):
+    UpdateVersionRanges(args.svn_version, Path(args.src_path),
+                        Path(args.patch_metadata_file))
+
+  dispatch_table = {
+      FailureModes.FAIL: _apply_all,
+      FailureModes.CONTINUE: _apply_all,
+      FailureModes.REMOVE_PATCHES: _remove,
+      FailureModes.DISABLE_PATCHES: _disable
+  }
+
+  if args_output.failure_mode in dispatch_table:
+    dispatch_table[args_output.failure_mode](args_output)
+  else:
+    # TODO(ajordanr): Legacy mode, remove when dispatch_table
+    # supports bisection.
+    # Get the results of handling the patches of the package.
+    patch_info = HandlePatches(args_output.svn_version,
+                               args_output.patch_metadata_file,
+                               args_output.filesdir_path, args_output.src_path,
+                               FailureModes(args_output.failure_mode),
+                               args_output.good_svn_version,
+                               args_output.num_patches_to_iterate,
+                               args_output.continue_bisection)
+    PrintPatchResults(patch_info)
 
 
 if __name__ == '__main__':
diff --git a/llvm_tools/patch_manager_unittest.py b/llvm_tools/patch_manager_unittest.py
index 452aea3..63d70a5 100755
--- a/llvm_tools/patch_manager_unittest.py
+++ b/llvm_tools/patch_manager_unittest.py
@@ -1,21 +1,21 @@
 #!/usr/bin/env python3
-# -*- coding: utf-8 -*-
 # Copyright 2019 The ChromiumOS Authors. All rights reserved.
 # Use of this source code is governed by a BSD-style license that can be
 # found in the LICENSE file.
 
 """Unit tests when handling patches."""
 
-from __future__ import print_function
-
 import json
 import os
+from pathlib import Path
 import subprocess
+import tempfile
+from typing import Callable
 import unittest
 import unittest.mock as mock
 
-import patch_manager
 from failure_modes import FailureModes
+import patch_manager
 from test_helpers import CallCountsToMockFunctions
 from test_helpers import CreateTemporaryJsonFile
 from test_helpers import WritePrettyJsonFile
@@ -189,6 +189,83 @@
     self.assertEqual(patch_manager.GetPatchMetadata(test_patch),
                      expected_patch_metadata)
 
+  def testRemoveOldPatches(self):
+    """Can remove old patches from PATCHES.json."""
+    one_patch_dict = {
+        'metadata': {
+            'title': '[some label] hello world',
+        },
+        'platforms': [
+            'chromiumos',
+        ],
+        'rel_patch_path': 'x/y/z',
+        'version_range': {
+            'from': 4,
+            'until': 5,
+        }
+    }
+    patches = [
+        one_patch_dict,
+        {
+            **one_patch_dict, 'version_range': {
+                'until': None
+            }
+        },
+        {
+            **one_patch_dict, 'version_range': {
+                'from': 100
+            }
+        },
+        {
+            **one_patch_dict, 'version_range': {
+                'until': 8
+            }
+        },
+    ]
+    cases = [
+        (0, lambda x: self.assertEqual(len(x), 4)),
+        (6, lambda x: self.assertEqual(len(x), 3)),
+        (8, lambda x: self.assertEqual(len(x), 2)),
+        (1000, lambda x: self.assertEqual(len(x), 2)),
+    ]
+
+    def _t(dirname: str, svn_version: int, assertion_f: Callable):
+      json_filepath = Path(dirname) / 'PATCHES.json'
+      with json_filepath.open('w', encoding='utf-8') as f:
+        json.dump(patches, f)
+      patch_manager.RemoveOldPatches(svn_version, Path(), json_filepath)
+      with json_filepath.open('r', encoding='utf-8') as f:
+        result = json.load(f)
+      assertion_f(result)
+
+    with tempfile.TemporaryDirectory(
+        prefix='patch_manager_unittest') as dirname:
+      for r, a in cases:
+        _t(dirname, r, a)
+
+  def testIsGitDirty(self):
+    """Test if a git directory has uncommitted changes."""
+    with tempfile.TemporaryDirectory(
+        prefix='patch_manager_unittest') as dirname:
+      dirpath = Path(dirname)
+
+      def _run_h(cmd):
+        subprocess.run(cmd, cwd=dirpath, stdout=subprocess.DEVNULL, check=True)
+
+      _run_h(['git', 'init'])
+      self.assertFalse(patch_manager.IsGitDirty(dirpath))
+      test_file = dirpath / 'test_file'
+      test_file.touch()
+      self.assertTrue(patch_manager.IsGitDirty(dirpath))
+      _run_h(['git', 'add', '.'])
+      _run_h(['git', 'commit', '-m', 'test'])
+      self.assertFalse(patch_manager.IsGitDirty(dirpath))
+      test_file.touch()
+      self.assertFalse(patch_manager.IsGitDirty(dirpath))
+      with test_file.open('w', encoding='utf-8'):
+        test_file.write_text('abc')
+      self.assertTrue(patch_manager.IsGitDirty(dirpath))
+
   def testFailedToApplyPatchWhenInvalidSrcPathIsPassedIn(self):
     src_path = '/abs/path/to/src'