Rework compat bindings. (#47863)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/47863
Test Plan: Imported from OSS
Reviewed By: ngimel
Differential Revision: D25199261
Pulled By: robieta
fbshipit-source-id: 0a4a0409ddb75c1bf66cd31d67b55080227b1679
diff --git a/.gitignore b/.gitignore
index 3d2e85b..d1f0643 100644
--- a/.gitignore
+++ b/.gitignore
@@ -93,6 +93,8 @@
torch/include/
torch/share/
torch/test/
+torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h
+torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h
torch/version.py
# Root level file used in CI to specify certain env configs.
# E.g., see .circleci/config.yaml
diff --git a/setup.py b/setup.py
index fd777b1..4ff3ef4 100644
--- a/setup.py
+++ b/setup.py
@@ -327,8 +327,16 @@
    # Use copies instead of symbolic links.
    # Windows has very poor support for them.
- sym_files = ['tools/shared/_utils_internal.py']
- orig_files = ['torch/_utils_internal.py']
+ sym_files = [
+ 'tools/shared/_utils_internal.py',
+ 'torch/utils/benchmark/utils/valgrind_wrapper/callgrind.h',
+ 'torch/utils/benchmark/utils/valgrind_wrapper/valgrind.h',
+ ]
+ orig_files = [
+ 'torch/_utils_internal.py',
+ 'third_party/valgrind-headers/callgrind.h',
+ 'third_party/valgrind-headers/valgrind.h',
+ ]
for sym_file, orig_file in zip(sym_files, orig_files):
same = False
if os.path.exists(sym_file):
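The hunk above truncates before the copy itself. For context, the consuming loop works roughly as follows (a minimal sketch, assuming the `filecmp`/`shutil` approach implied by the surrounding lines rather than the verbatim setup.py code):

    import filecmp
    import shutil
    import os

    sym_files = ['tools/shared/_utils_internal.py']  # destination copies
    orig_files = ['torch/_utils_internal.py']        # source files

    # Copy each original file to its destination unless an identical copy is
    # already there. Copies are used instead of symlinks because Windows has
    # very poor support for symlinks.
    for sym_file, orig_file in zip(sym_files, orig_files):
        same = False
        if os.path.exists(sym_file):
            # Skip the copy when contents already match, so incremental
            # builds do not see spurious file-modification timestamps.
            same = filecmp.cmp(sym_file, orig_file, shallow=False)
        if not same:
            shutil.copyfile(orig_file, sym_file)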
@@ -907,6 +915,8 @@
'share/cmake/Gloo/*.cmake',
'share/cmake/Tensorpipe/*.cmake',
'share/cmake/Torch/*.cmake',
+ 'utils/benchmark/utils/valgrind_wrapper/*.cpp',
+ 'utils/benchmark/utils/valgrind_wrapper/*.h',
],
'caffe2': [
'python/serialized_test/data/operator_test/*.zip',
diff --git a/torch/utils/benchmark/utils/_stubs.py b/torch/utils/benchmark/utils/_stubs.py
new file mode 100644
index 0000000..88c0a2b
--- /dev/null
+++ b/torch/utils/benchmark/utils/_stubs.py
@@ -0,0 +1,24 @@
+import sys
+from typing import TYPE_CHECKING
+
+
+if TYPE_CHECKING or sys.version_info >= (3, 8):
+ from typing import Protocol
+else:
+ from typing_extensions import Protocol
+
+
+class CallgrindModuleType(Protocol):
+    """Replicates the Valgrind endpoints in `torch._C`.
+
+ These bindings are used to collect Callgrind profiles on earlier versions
+ of PyTorch and will eventually be removed.
+ """
+ __file__: str
+ __name__: str
+
+ def _valgrind_supported_platform(self) -> bool:
+ ...
+
+ def _valgrind_toggle(self) -> None:
+ ...
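Because `CallgrindModuleType` is a `Protocol`, anything with a matching set of attributes satisfies it structurally; neither real extension modules nor test doubles need to inherit from it. A minimal illustration (the `FakeBindings` class is hypothetical, for demonstration only):

    from torch.utils.benchmark.utils._stubs import CallgrindModuleType

    class FakeBindings:
        """Hypothetical stand-in with the same structural interface."""
        def __init__(self) -> None:
            self.__file__ = "<fake>"
            self.__name__ = "fake_bindings"

        def _valgrind_supported_platform(self) -> bool:
            return False

        def _valgrind_toggle(self) -> None:
            pass

    # Structural typing: this assignment type-checks even though FakeBindings
    # never inherits from CallgrindModuleType.
    bindings: CallgrindModuleType = FakeBindings()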
diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py
new file mode 100644
index 0000000..6e765f1
--- /dev/null
+++ b/torch/utils/benchmark/utils/cpp_jit.py
@@ -0,0 +1,68 @@
+"""JIT C++ strings into executables."""
+import os
+import threading
+from typing import List, Optional
+
+import torch
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType
+from torch.utils import cpp_extension
+
+
+LOCK = threading.Lock()
+SOURCE_ROOT = os.path.split(os.path.abspath(__file__))[0]
+
+# BACK_TESTING_NOTE:
+# There are two workflows where this code could be used. One is the obvious
+# case where someone simply builds or installs PyTorch and uses Timer.
+# The other is that the entire `torch/utils/benchmark` folder from a CURRENT
+# PyTorch checkout is copy-pasted into a much OLDER version of the PyTorch
+# source code. This is what we refer to here as "back testing". The
+# rationale is that we might want to use current tooling to study some
+# aspect of an earlier version of PyTorch (e.g., a regression).
+#
+# The problem is that Timer relies on several aspects of core PyTorch, namely
+# some binding functions for Valgrind symbols in `torch._C` and the
+# `torch.__config__._cxx_flags()` method. A naive copy-paste would not
+# work, because the symbols of interest are absent from earlier versions
+# of PyTorch. To work around this, we add back testing shims. These shims
+# never activate during normal use, but they allow Timer to function
+# outside of the "correct" version of PyTorch by emulating functionality
+# that was added later.
+#
+# These shims are temporary, and as Timer becomes more integrated with
+# PyTorch, the cost and complexity of such shims will increase. Once back
+# testing is no longer required (that is, once enough historical analysis
+# has been done that the shims no longer justify their maintenance and
+# code complexity costs), the back testing paths will be removed.
+
+if hasattr(torch.__config__, "_cxx_flags"):
+ CXX_FLAGS = torch.__config__._cxx_flags().strip().split()
+ if "-g" not in CXX_FLAGS:
+ CXX_FLAGS.append("-g")
+else:
+ # FIXME: Remove when back testing is no longer required.
+ CXX_FLAGS = ["-O2", "-fPIC", "-g"]
+
+EXTRA_INCLUDE_PATHS: List[str] = [os.path.join(SOURCE_ROOT, "valgrind_wrapper")]
+CONDA_PREFIX = os.getenv("CONDA_PREFIX")
+if CONDA_PREFIX is not None:
+    # cpp_extension.load automatically searches /usr/include, but not the conda include directory.
+ EXTRA_INCLUDE_PATHS.append(os.path.join(CONDA_PREFIX, "include"))
+
+
+COMPAT_CALLGRIND_BINDINGS: Optional[CallgrindModuleType] = None
+def get_compat_bindings() -> CallgrindModuleType:
+ with LOCK:
+ global COMPAT_CALLGRIND_BINDINGS
+ if COMPAT_CALLGRIND_BINDINGS is None:
+ COMPAT_CALLGRIND_BINDINGS = cpp_extension.load(
+ name="callgrind_bindings",
+ sources=[os.path.join(
+ SOURCE_ROOT,
+ "valgrind_wrapper",
+ "compat_bindings.cpp"
+ )],
+ extra_cflags=CXX_FLAGS,
+ extra_include_paths=EXTRA_INCLUDE_PATHS,
+ )
+ return COMPAT_CALLGRIND_BINDINGS
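`get_compat_bindings` is a lazily initialized, process-wide singleton: the lock plus the module-level cache guarantee the extension is compiled at most once, even with concurrent callers. A hedged usage sketch (only meaningful on an older PyTorch whose `torch._C` lacks the Valgrind bindings, and only when the process runs under `valgrind --tool=callgrind`; `run_workload` is a hypothetical function):

    from torch.utils.benchmark.utils import cpp_jit

    bindings = cpp_jit.get_compat_bindings()  # compiles on first call, cached afterwards
    if bindings._valgrind_supported_platform():
        bindings._valgrind_toggle()  # toggle collection on
        run_workload()               # hypothetical function under measurement
        bindings._valgrind_toggle()  # toggle collection off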
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
new file mode 100644
index 0000000..b52626f
--- /dev/null
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.cpp
@@ -0,0 +1,26 @@
+/* Used to collect profiles of old versions of PyTorch. */
+#include <callgrind.h>
+#include <c10/util/Exception.h>  // Provides TORCH_CHECK.
+#include <pybind11/pybind11.h>
+
+
+bool _valgrind_supported_platform() {
+ #if defined(NVALGRIND)
+ return false;
+ #else
+ return true;
+ #endif
+}
+
+void _valgrind_toggle() {
+ #if defined(NVALGRIND)
+ TORCH_CHECK(false, "Valgrind is not supported.");
+ #else
+ CALLGRIND_TOGGLE_COLLECT;
+ #endif
+}
+
+PYBIND11_MODULE(callgrind_bindings, m) {
+ m.def("_valgrind_supported_platform", &_valgrind_supported_platform);
+ m.def("_valgrind_toggle", &_valgrind_toggle);
+}
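`CALLGRIND_TOGGLE_COLLECT` is a no-op unless the process is running under Callgrind, so these bindings are harmless to load in a normal interpreter. A hedged sketch of how such a measurement process might be launched (the flags are standard Valgrind CLI options; `measure.py` is a hypothetical script that calls the bindings around its hot region):

    import subprocess
    import sys

    # Start with collection off so that only the region bracketed by the two
    # _valgrind_toggle() calls in measure.py contributes to the counts.
    subprocess.run([
        "valgrind",
        "--tool=callgrind",
        "--collect-atstart=no",
        "--callgrind-out-file=out.callgrind",
        sys.executable, "measure.py",
    ], check=True)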
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.py b/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.py
deleted file mode 100644
index b7404a6..0000000
--- a/torch/utils/benchmark/utils/valgrind_wrapper/compat_bindings.py
+++ /dev/null
@@ -1,41 +0,0 @@
-"""Allow Timer.collect_callgrind to be used on earlier versions of PyTorch
-
-FIXME: Remove this module once we no longer need to back test.
-"""
-import os
-import textwrap
-from typing import List
-
-from torch.utils.cpp_extension import load_inline
-
-
-# load_inline will automatically search /usr/include, but not conda include.
-extra_include_paths: List[str] = []
-conda_prefix = os.getenv("CONDA_PREFIX")
-if conda_prefix is not None:
- extra_include_paths = [os.path.join(conda_prefix, "include")]
-
-bindings = load_inline(
- name="callgrind_bindings",
- cpp_sources=textwrap.dedent("""
- #include <valgrind/callgrind.h>
-
- bool _valgrind_supported_platform() {
- #if defined(NVALGRIND)
- return false;
- #else
- return true;
- #endif
- }
-
- void _valgrind_toggle() {
- #if defined(NVALGRIND)
- TORCH_CHECK(false, "Valgrind is not supported.");
- #else
- CALLGRIND_TOGGLE_COLLECT;
- #endif
- }
- """),
- extra_include_paths=extra_include_paths,
- functions=["_valgrind_supported_platform", "_valgrind_toggle"],
-)
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
index bad9df9..b851367 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -11,13 +11,13 @@
import sys
import tempfile
import textwrap
-from types import ModuleType
from typing import (
cast, Any, Callable, DefaultDict, Dict, Generator, List, NamedTuple,
Optional, Tuple, Union, TYPE_CHECKING)
import torch
-from torch.utils.benchmark.utils import common
+from torch.utils.benchmark.utils import common, cpp_jit
+from torch.utils.benchmark.utils._stubs import CallgrindModuleType
__all__ = ["FunctionCount", "FunctionCounts", "CallgrindStats", "CopyIfCallgrind"]
@@ -444,17 +444,14 @@
class _ValgrindWrapper(object):
def __init__(self) -> None:
- self._bindings_module: Optional[ModuleType] = None
+ self._bindings_module: Optional[CallgrindModuleType] = None
if hasattr(torch._C, "_valgrind_supported_platform"):
self._supported_platform: bool = torch._C._valgrind_supported_platform()
else:
print("Callgrind bindings are not present in `torch._C`. JIT-ing bindings.")
- # This import will JIT the Callgrind control bindings, so don't
- # invoke unless we know we'll need it.
- from torch.utils.benchmark.utils.valgrind_wrapper.compat_bindings import bindings
- self._bindings_module = bindings
- self._supported_platform = bindings._valgrind_supported_platform()
+ self._bindings_module = cpp_jit.get_compat_bindings()
+ self._supported_platform = self._bindings_module._valgrind_supported_platform()
self._commands_available: Dict[str, bool] = {}
if self._supported_platform:
@@ -643,7 +640,7 @@
number: int,
error_log: str,
stat_log: str,
- bindings: Optional[ModuleType],
+ bindings: Optional[CallgrindModuleType],
) -> str:
# The naive template looks something like:
# "for _ in range({number}): {stmt}"
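With this change, `Timer.collect_callgrind` transparently uses the native bindings in `torch._C` when they exist and falls back to the JIT-compiled compat module otherwise, so the user-facing API is unchanged. A typical call (the stmt/setup strings are illustrative):

    from torch.utils.benchmark import Timer

    timer = Timer(
        stmt="x + y",
        setup="import torch; x = torch.ones((4, 4)); y = torch.ones((4, 4))",
    )
    # Requires Valgrind; works on current PyTorch (native bindings) and, via
    # the compat path, on older versions when this folder is copied back.
    stats = timer.collect_callgrind(number=100)
    print(stats.counts())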