Definition infrastructure for instruction count microbenchmarks (#53293)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53293
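Instruction count benchmarks need some extra includes for IValues, but a hook for top-level code is also just generally useful, so this adds a C++-only `global_setup` argument to `Timer` whose contents are placed at the top level of the generated source file. (Unlike Python, where you can drop imports anywhere, C++ will generally fail to compile if you `#include` a header inside a function body.)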
Test Plan: Imported from OSS
Reviewed By: Chillee
Differential Revision: D26906684
Pulled By: robieta
fbshipit-source-id: cbdfd79d3b8383100ff2e6857b6f309c387cbe2a
diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py
index e15a0fe..7a653c3 100644
--- a/test/benchmark_utils/test_benchmark_utils.py
+++ b/test/benchmark_utils/test_benchmark_utils.py
@@ -167,8 +167,15 @@
@unittest.skipIf(IS_SANDCASTLE, "C++ timing is OSS only.")
def test_cpp_timer(self):
timer = benchmark_utils.Timer(
- "torch::Tensor y = x + 1;",
+ """
+ #ifndef TIMER_GLOBAL_CHECK
+ static_assert(false);
+ #endif
+
+ torch::Tensor y = x + 1;
+ """,
setup="torch::Tensor x = torch::empty({1});",
+ global_setup="#define TIMER_GLOBAL_CHECK",
timer=timeit.default_timer,
language=benchmark_utils.Language.CPP,
)
diff --git a/torch/utils/benchmark/utils/_stubs.py b/torch/utils/benchmark/utils/_stubs.py
index e2ab6ec..0b80a08 100644
--- a/torch/utils/benchmark/utils/_stubs.py
+++ b/torch/utils/benchmark/utils/_stubs.py
@@ -15,7 +15,8 @@
stmt: str,
setup: str,
timer: Callable[[], float],
- globals: Dict[str, Any]
+ globals: Dict[str, Any],
+ **kwargs: Any,
) -> None:
...
diff --git a/torch/utils/benchmark/utils/common.py b/torch/utils/benchmark/utils/common.py
index 6ad4016..758fd7c 100644
--- a/torch/utils/benchmark/utils/common.py
+++ b/torch/utils/benchmark/utils/common.py
@@ -28,6 +28,7 @@
"""Container for information used to define a Timer. (except globals)"""
stmt: str
setup: str
+ global_setup: str = ""
label: Optional[str] = None
sub_label: Optional[str] = None
description: Optional[str] = None
diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py
index 9cd3e17..9160b8f 100644
--- a/torch/utils/benchmark/utils/cpp_jit.py
+++ b/torch/utils/benchmark/utils/cpp_jit.py
@@ -98,8 +98,16 @@
return COMPAT_CALLGRIND_BINDINGS
-def _compile_template(stmt: str, setup: str, src: str, is_standalone: bool) -> Any:
+def _compile_template(
+ *,
+ stmt: str,
+ setup: str,
+ global_setup: str,
+ src: str,
+ is_standalone: bool
+) -> Any:
for before, after, indentation in (
+ ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0),
("// SETUP_TEMPLATE_LOCATION", setup, 4),
("// STMT_TEMPLATE_LOCATION", stmt, 8)
):
@@ -140,21 +148,21 @@
)
-def compile_timeit_template(stmt: str, setup: str) -> TimeitModuleType:
+def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp")
with open(template_path, "rt") as f:
src: str = f.read()
- module = _compile_template(stmt, setup, src, is_standalone=False)
+ module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False)
assert isinstance(module, TimeitModuleType)
return module
-def compile_callgrind_template(stmt: str, setup: str) -> str:
+def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp")
with open(template_path, "rt") as f:
src: str = f.read()
- target = _compile_template(stmt, setup, src, is_standalone=True)
+ target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True)
assert isinstance(target, str)
return target
diff --git a/torch/utils/benchmark/utils/timeit_template.cpp b/torch/utils/benchmark/utils/timeit_template.cpp
index 01d62ef..fd46dcb 100644
--- a/torch/utils/benchmark/utils/timeit_template.cpp
+++ b/torch/utils/benchmark/utils/timeit_template.cpp
@@ -1,6 +1,7 @@
/* C++ template for Timer.timeit
This template will be consumed by `cpp_jit.py`, and will replace:
+ `GLOBAL_SETUP_TEMPLATE_LOCATION`,
`SETUP_TEMPLATE_LOCATION`
and
`STMT_TEMPLATE_LOCATION`
@@ -11,6 +12,8 @@
#include <pybind11/pybind11.h>
#include <torch/extension.h>
+// Global setup. (e.g. #includes)
+// GLOBAL_SETUP_TEMPLATE_LOCATION
double timeit(int n) {
// Setup
diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py
index f7001b6..371bf18 100644
--- a/torch/utils/benchmark/utils/timer.py
+++ b/torch/utils/benchmark/utils/timer.py
@@ -32,6 +32,7 @@
self,
stmt: str,
setup: str,
+ global_setup: str,
timer: Callable[[], float],
globals: Dict[str, Any],
) -> None:
@@ -50,13 +51,15 @@
self._stmt: str = textwrap.dedent(stmt)
self._setup: str = textwrap.dedent(setup)
+ self._global_setup: str = textwrap.dedent(global_setup)
self._timeit_module: Optional[TimeitModuleType] = None
def timeit(self, number: int) -> float:
if self._timeit_module is None:
self._timeit_module = cpp_jit.compile_timeit_template(
- self._stmt,
- self._setup,
+ stmt=self._stmt,
+ setup=self._setup,
+ global_setup=self._global_setup,
)
return self._timeit_module.timeit(number)
@@ -111,6 +114,10 @@
setup: Optional setup code. Used to define variables used in `stmt`
+ global_setup: (C++ only)
+ Code which is placed at the top level of the file for things like
+ `#include` statements.
+
timer:
Callable which returns the current time. If PyTorch was built
without CUDA or there is no GPU present, this defaults to
@@ -172,6 +179,7 @@
self,
stmt: str = "pass",
setup: str = "pass",
+ global_setup: str = "",
timer: Callable[[], float] = timer,
globals: Optional[Dict[str, Any]] = None,
label: Optional[str] = None,
@@ -187,16 +195,24 @@
# We copy `globals` to prevent mutations from leaking.
# (For instance, `eval` adds the `__builtins__` key)
self._globals = dict(globals or {})
+
+ timer_kwargs = {}
if language in (Language.PYTHON, "py", "python"):
# Include `torch` if not specified as a convenience feature.
self._globals.setdefault("torch", torch)
self._language: Language = Language.PYTHON
+ if global_setup:
+ raise ValueError(
+ f"global_setup is C++ only, got `{global_setup}`. Most "
+ "likely this code can simply be moved to `setup`."
+ )
elif language in (Language.CPP, "cpp", "c++"):
assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped."
self._timer_cls = CPPTimer
setup = ("" if setup == "pass" else setup)
self._language = Language.CPP
+ timer_kwargs["global_setup"] = global_setup
else:
raise ValueError(f"Invalid language `{language}`.")
@@ -222,10 +238,12 @@
setup=setup,
timer=timer,
globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals),
+ **timer_kwargs,
)
self._task_spec = common.TaskSpec(
stmt=stmt,
setup=setup,
+ global_setup=global_setup,
label=label,
sub_label=sub_label,
description=description,
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp b/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
index a64484f..e14b347 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
@@ -1,6 +1,7 @@
/* C++ template for Timer.collect_callgrind
This template will be consumed by `cpp_jit.py`, and will replace:
+ `GLOBAL_SETUP_TEMPLATE_LOCATION`,
`SETUP_TEMPLATE_LOCATION`
and
`STMT_TEMPLATE_LOCATION`
@@ -12,10 +13,14 @@
#include <callgrind.h>
#include <torch/torch.h>
+// Global setup. (e.g. #includes)
+// GLOBAL_SETUP_TEMPLATE_LOCATION
+
#if defined(NVALGRIND)
static_assert(false);
#endif
+
int main(int argc, char* argv[]) {
// This file should only be called inside of `Timer`, so we can adopt a
// very simple and rigid argument parsing scheme.
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
index 5081325..1f6ce28 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -618,8 +618,9 @@
run_loop_cmd = ["python", script_file]
else:
run_loop_exec = cpp_jit.compile_callgrind_template(
- task_spec.stmt,
- task_spec.setup,
+ stmt=task_spec.stmt,
+ setup=task_spec.setup,
+ global_setup=task_spec.global_setup,
)
run_loop_cmd = [
run_loop_exec,