pw_presubmit: Allow passing stdin args from git hooks

Change-Id: Ifba68421588881eede91bc8eb7388fe8cdc837bd
Reviewed-on: https://pigweed-review.googlesource.com/c/pigweed/pigweed/+/93681
Reviewed-by: Wyatt Hepler <hepler@google.com>
Commit-Queue: Daniel Zheng <dhhzheng@google.com>
diff --git a/pw_presubmit/docs.rst b/pw_presubmit/docs.rst
index c99f496..5519965 100644
--- a/pw_presubmit/docs.rst
+++ b/pw_presubmit/docs.rst
@@ -191,7 +191,7 @@
   from pathlib import Path
   import re
   import sys
-  from typing import List, Pattern
+  from typing import List, Optional, Pattern
 
   try:
       import pw_cli.log
@@ -283,16 +283,34 @@
   PROGRAMS = pw_presubmit.Programs(other=OTHER, quick=QUICK, full=FULL)
 
 
-  def run(install: bool, **presubmit_args) -> int:
+  #
+  # Allowlist of remote refs for presubmit. If the remote ref being pushed to
+  # matches any of these values (with regex matching), then the presubmits
+  # checks will be run before pushing.
+  #
+  PRE_PUSH_REMOTE_REF_ALLOWLIST = (
+      'refs/for/main',
+  )
+
+
+  def run(install: bool, remote_ref: Optional[str],  **presubmit_args) -> int:
       """Process the --install argument then invoke pw_presubmit."""
 
       # Install the presubmit Git pre-push hook, if requested.
       if install:
-          install_hook(__file__, 'pre-push', ['--base', 'HEAD~'],
+          # '$remote_ref' will be replaced by the actual value of the remote ref
+          # at runtime.
+          install_hook(__file__, 'pre-push',
+                       ['--base', 'HEAD~', '--remote-ref', '$remote_ref'],
                        git_repo.root())
           return 0
 
-      return cli.run(root=PROJECT_ROOT, **presubmit_args)
+      # Run the checks if either no remote_ref was passed, or if the remote ref
+      # matches anything in the allowlist.
+      if remote_ref is None or any(
+              re.search(pattern, remote_ref)
+              for pattern in PRE_PUSH_REMOTE_REF_ALLOWLIST):
+          return cli.run(root=PROJECT_ROOT, **presubmit_args)
 
 
   def main() -> int:
@@ -306,6 +324,16 @@
           action='store_true',
           help='Install the presubmit as a Git pre-push hook and exit.')
 
+      # Define an optional flag to pass the remote ref into this script, if it
+      # is run as a pre-push hook. The destination variable in the parsed args
+      # will be `remote_ref`, as dashes are replaced with underscores to make
+      # valid variable names.
+      parser.add_argument(
+          '--remote-ref',
+          default=None,
+          nargs='?',  # Make optional.
+          help='Remote ref of the push command, for use by the pre-push hook.')
+
       return run(**vars(parser.parse_args()))
 
   if __name__ == '__main__':
diff --git a/pw_presubmit/py/pw_presubmit/install_hook.py b/pw_presubmit/py/pw_presubmit/install_hook.py
index 62363f9..2a0ee34 100755
--- a/pw_presubmit/py/pw_presubmit/install_hook.py
+++ b/pw_presubmit/py/pw_presubmit/install_hook.py
@@ -33,11 +33,42 @@
                        stdout=subprocess.PIPE).stdout.strip().decode())
 
 
+def _stdin_args_for_hook(hook) -> Sequence[str]:
+    """Gives stdin arguments for each hook.
+
+    See https://git-scm.com/docs/githooks for more information.
+    """
+    if hook == 'pre-push':
+        return ('local_ref', 'local_object_name', 'remote_ref',
+                'remote_object_name')
+    if hook in ('pre-receive', 'post-receive', 'reference-transaction'):
+        return ('old_value', 'new_value', 'ref_name')
+    if hook == 'post-rewrite':
+        return ('old_object_name', 'new_object_name')
+    return ()
+
+
+def _replace_arg_in_hook(arg: str, unquoted_args: Sequence[str]) -> str:
+    if arg in unquoted_args:
+        return arg
+    return shlex.quote(arg)
+
+
 def install_hook(script,
                  hook: str,
                  args: Sequence[str] = (),
                  repository: Union[Path, str] = '.') -> None:
-    """Installs a simple Git hook that calls a script with arguments."""
+    """Installs a simple Git hook that calls a script with arguments.
+
+    Args:
+      script: Path to the script to run in the hook.
+      hook: Git hook to install, e.g. 'pre-push'.
+      args: Arguments to pass to `script` when it is run in the hook. These will
+        be sanitised with `shlex.quote`, except for any arguments are equal to
+        f'${stdin_arg}' for some `stdin_arg` which matches a standard-input
+        argument to the git hook.
+      repository: Repository to install the hook in.
+    """
     root = git_repo_root(repository).resolve()
     script = os.path.relpath(script, root)
 
@@ -52,7 +83,12 @@
 
     hook_path.parent.mkdir(exist_ok=True)
 
-    command = ' '.join(shlex.quote(arg) for arg in (script, *args))
+    hook_stdin_args = _stdin_args_for_hook(hook)
+    read_stdin_command = 'read ' + ' '.join(hook_stdin_args)
+
+    unquoted_args = [f'${arg}' for arg in hook_stdin_args]
+    script_command = ' '.join(
+        _replace_arg_in_hook(arg, unquoted_args) for arg in (script, *args))
 
     with hook_path.open('w') as file:
         line = lambda *args: print(*args, file=file)
@@ -66,7 +102,10 @@
         line('# submodule hook.')
         line('unset $(git rev-parse --local-env-vars)')
         line()
-        line(command)
+        line('# Read the stdin args for the hook, made available by git.')
+        line(read_stdin_command)
+        line()
+        line(script_command)
 
     hook_path.chmod(0o755)
     logging.info('Created %s hook for %s at %s', hook, script, hook_path)