Rework compat bindings. (#47863)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47863

This reworks the compat Callgrind bindings used for back testing:
* `compat_bindings.py`, which JIT-compiled an inline C++ string at import time
  via `load_inline`, is replaced by a standalone `compat_bindings.cpp` that is
  built lazily through the new `cpp_jit.get_compat_bindings()` helper.
* A `CallgrindModuleType` Protocol is added in `_stubs.py` so the wrapper can
  be typed against either `torch._C` or the JIT-compiled module, rather than a
  bare `types.ModuleType`.
* The valgrind headers are copied from `third_party/valgrind-headers` into the
  package by `setup.py`, so the compat bindings can be built without a
  system-wide valgrind install.

Test Plan: Imported from OSS

Reviewed By: ngimel

Differential Revision: D25199261

Pulled By: robieta

fbshipit-source-id: 0a4a0409ddb75c1bf66cd31d67b55080227b1679
diff --git a/.gitignore b/.gitignore
index 3d2e85b..d1f0643 100644
--- a/.gitignore
+++ b/.gitignore
@@ -93,6 +93,8 @@
 torch/include/
 torch/share/
 torch/test/
+torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
+torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
 torch/version.py
 # Root level file used in CI to specify certain env configs.
 # E.g., see .circleci/config.yaml
diff --git a/setup.py b/setup.py
index fd777b1..4ff3ef4 100644
--- a/setup.py
+++ b/setup.py
@@ -327,8 +327,16 @@
 
     # Use copies instead of symbolic files.
     # Windows has very poor support for them.
-    sym_files = ['tools/shared/_utils_internal.py']
-    orig_files = ['torch/_utils_internal.py']
+    sym_files = [
+        'tools/shared/_utils_internal.py',
+        'torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h',
+        'torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h',
+    ]
+    orig_files = [
+        'torch/_utils_internal.py',
+        'third_party/valgrind-headers/callgrind.h',
+        'third_party/valgrind-headers/valgrind.h',
+    ]
     for sym_file, orig_file in zip(sym_files, orig_files):
         same = False
         if os.path.exists(sym_file):
@@ -907,6 +915,8 @@
                 'share/cmake/Gloo/*.cmake',
                 'share/cmake/Tensorpipe/*.cmake',
                 'share/cmake/Torch/*.cmake',
+                'utils/benchmark/utils/valgrind_wrapper/*.cpp',
+                'utils/benchmark/utils/valgrind_wrapper/*.h',
             ],
             'caffe2': [
                 'python/serialized_test/data/operator_test/*.zip',
diff --git a/torch/utils/benchmark/utils/_stubs.py b/torch/utils/benchmark/utils/_stubs.py
new file mode 100644
index 0000000..88c0a2b
--- /dev/null
+++ b/torch/utils/benchmark/utils/_stubs.py
@@ -0,0 +1,24 @@
+import sys
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING or sys.version_info >= (3, 8):
+    from typing import Protocol
+else:
+    from typing_extensions import Protocol
+
+
+class CallgrindModuleType(Protocol):
+    """Replicates the valgrind endpoints in `torch._C`.
+
+    These bindings are used to collect Callgrind profiles on earlier versions
+    of PyTorch and will eventually be removed.
+    """
+    __file__: str
+    __name__: str
+
+    def _valgrind_supported_platform(self) -> bool:
+        ...
+
+    def _valgrind_toggle(self) -> None:
+        ...
diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py
new file mode 100644
index 0000000..6e765f1
--- /dev/null
+++ b/torch/utils/benchmark/utils/cpp_jit.py
@@ -0,0 +1,68 @@
+"""JIT C++ strings into executables."""
+import os
+import threading
+from typing import List, Optional
+
+import torch
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType
+from torch.utils import cpp_extension
+
+
+LOCK = threading.Lock()
+SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0]
+
+# BACK_TESTING_NOTE:
+#   There are two workflows where this code could be used. One is the obvious
+#   case where someone simply builds or installs PyTorch and uses Timer.
+#   The other is that the entire `torch/utils/benchmark` folder from a CURRENT
+#   PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch
+#   source code. This is what we refer to here as "back testing". The rationale
+#   is that we might want to use current tooling to study some aspect of an
+#   earlier version of PyTorch. (e.g. a regression.)
+#
+#   The problem is that Timer relies on several aspects of core PyTorch, namely
+#   some binding functions for Valgrind symbols in `torch._C` and the
+#   `torch.__config__._cxx_flags()` method. If we were to naively copy the
+#   code around, it wouldn't work, as the symbols of interest aren't present
+#   in earlier versions of PyTorch. To work around this, we provide back
+#   testing shims. These shims will never activate during normal use, but will
+#   allow Timer to function outside of the "correct" version of PyTorch by
+#   emulating functionality that was added later.
+#
+#   These shims are temporary, and as Timer becomes more integrated with
+#   PyTorch, the cost and complexity of maintaining them will increase. Once
+#   back testing is no longer required (that is, once we have done enough
+#   historical analysis that the shims no longer justify their maintenance
+#   and code complexity costs), the back testing paths will be removed.
+
+if hasattr(torch.__config__, "_cxx_flags"):
+    CXX_FLAGS = torch.__config__._cxx_flags().strip().split()
+    if "-g" not in CXX_FLAGS:
+        CXX_FLAGS.append("-g")
+else:
+    # FIXME: Remove when back testing is no longer required.
+    CXX_FLAGS = ["-O2", "-fPIC", "-g"]
+
+EXTRA_INCLUDE_PATHS: List[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")]
+CONDA_PREFIX = os.getenv("CONDA_PREFIX")
+if CONDA_PREFIX is not None:
+    # Load will automatically search /usr/include, but not conda include.
+    EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include"))
+
+
+COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None
+def get_compat_bindings() -> CallgrindModuleType:
+    with LOCK:
+        global COMPAT_CALLGRIND_BINDINGS
+        if COMPAT_CALLGRIND_BINDINGS is None:
+            COMPAT_CALLGRIND_BINDINGS = cpp_extension.load(
+                name="callgrind_bindings",
+                sources=[os.path.join(
+                    SOURCE_ROOT,
+                    "valgrind_wrapper",
+                    "compat_bindings.cpp"
+                )],
+                extra_cflags=CXX_FLAGS,
+                extra_include_paths=EXTRA_INCLUDE_PATHS,
+            )
+    return COMPAT_CALLGRIND_BINDINGS
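
On an old checkout the helper above is the only way to reach the Valgrind
endpoints, and checking the cached module under `LOCK` ensures the extension
is compiled at most once per process. A minimal usage sketch (assumes a C++
toolchain is available; the first call pays a one-time `cpp_extension.load`
compile cost, later calls return the cached module):

    from torch.utils.benchmark.utils import cpp_jit

    bindings = cpp_jit.get_compat_bindings()  # JIT-compiles on first call
    if bindings._valgrind_supported_platform():
        bindings._valgrind_toggle()  # start collecting (under callgrind)
        # ... region of interest ...
        bindings._valgrind_toggle()  # stop collecting
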
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
new file mode 100644
index 0000000..b52626f
--- /dev/null
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
@@ -0,0 +1,26 @@
+/* Used to collect profiles of old versions of PyTorch. */
+#include <callgrind.h>
+#include <c10/util/Exception.h>  // for TORCH_CHECK in the NVALGRIND branch
+#include <pybind11/pybind11.h>
+
+
+bool _valgrind_supported_platform() {
+    #if defined(NVALGRIND)
+    return false;
+    #else
+    return true;
+    #endif
+}
+
+void _valgrind_toggle() {
+    #if defined(NVALGRIND)
+    TORCH_CHECK(false, "Valgrind is not supported.");
+    #else
+    CALLGRIND_TOGGLE_COLLECT;
+    #endif
+}
+
+PYBIND11_MODULE(callgrind_bindings, m) {
+    m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
+    m.def("_valgrind_toggle", &_valgrind_toggle);
+}
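
Note that `CALLGRIND_TOGGLE_COLLECT` is a no-op unless the process runs under
callgrind with collection disabled at startup. Roughly how that is arranged
(a simplified sketch of the subprocess call made by `timer_interface.py`, not
the exact command it constructs; `my_benchmark.py` is a hypothetical script
that brackets its region of interest with `_valgrind_toggle()` calls):

    import subprocess

    subprocess.run([
        "valgrind", "--tool=callgrind",
        "--collect-atstart=no",               # wait for the first toggle
        "--callgrind-out-file=callgrind.out",
        "python", "my_benchmark.py",
    ], check=True)
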
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.py b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.py
deleted file mode 100644
index b7404a6..0000000
--- a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""Allow Timer.collect_callgrind to be used on earlier versions of PyTorch
-
-FIXME: Remove this module once we no longer need to back test.
-"""
-import os
-import textwrap
-from typing import List
-
-from torch.utils.cpp_extension import load_inline
-
-
-# load_inline will automatically search /usr/include, but not conda include.
-extra_include_paths: List[str] = []
-conda_prefix = os.getenv("CONDA_PREFIX")
-if conda_prefix is not None:
-    extra_include_paths = [os.path.join(conda_prefix, "include")]
-
-bindings = load_inline(
-    name="callgrind_bindings",
-    cpp_sources=textwrap.dedent("""
-    #include <valgrind/callgrind.h>
-
-    bool _valgrind_supported_platform() {
-        #if defined(NVALGRIND)
-        return false;
-        #else
-        return true;
-        #endif
-    }
-
-    void _valgrind_toggle() {
-        #if defined(NVALGRIND)
-        TORCH_CHECK(false, "Valgrind is not supported.");
-        #else
-        CALLGRIND_TOGGLE_COLLECT;
-        #endif
-    }
-    """),
-    extra_include_paths=extra_include_paths,
-    functions=["_valgrind_supported_platform", "_valgrind_toggle"],
-)
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
index bad9df9..b851367 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -11,13 +11,13 @@
 import sys
 import tempfile
 import textwrap
-from types import ModuleType
 from typing import (
     cast, Any, Callable, DefaultDict, Dict, Generator, List, NamedTuple,
     Optional, Tuple, Union, TYPE_CHECKING)
 
 import torch
-from torch.utils.benchmark.utils import common
+from torch.utils.benchmark.utils import common, cpp_jit
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType
 
 
 __all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
@@ -444,17 +444,14 @@
 
 class _ValgrindWrapper(object):
     def __init__(self) -> None:
-        self._bindings_module: Optional[ModuleType] = None
+        self._bindings_module: Optional[CallgrindModuleType] = None
         if hasattr(torch._C, "_valgrind_supported_platform"):
             self._supported_platform: bool = torch._C._valgrind_supported_platform()
 
         else:
             print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
-            # This import will JIT the Callgrind control bindings, so don't
-            # invoke unless we know we'll need it.
-            from torch.utils.benchmark.utils.valgrind_wrapper.compat_bindings import bindings
-            self._bindings_module = bindings
-            self._supported_platform = bindings._valgrind_supported_platform()
+            self._bindings_module = cpp_jit.get_compat_bindings()
+            self._supported_platform = self._bindings_module._valgrind_supported_platform()
 
         self._commands_available: Dict[str, bool] = {}
         if self._supported_platform:
@@ -643,7 +640,7 @@
         number: int,
         error_log: str,
         stat_log: str,
-        bindings: Optional[ModuleType],
+        bindings: Optional[CallgrindModuleType],
     ) -> str:
         # The naive template looks something like:
         #   "for _ in range({number}): {stmt}"