Definition infrastructure for instruction count ubenchmarks (#53293)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/53293

Instruction count benchmarks need some `#include`s for IValues, but a `global_setup` argument is also just generally useful. (Unlike Python, where you can drop imports anywhere, C++ will get very upset if you `#include` in a function body...)
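
For context, a minimal usage sketch of the new argument (illustrative only, not part of this PR; it mirrors the updated `test_cpp_timer` below, the specific `#include` is just an example, and it assumes a working C++ toolchain since the snippet is JIT compiled on first use):

```python
import torch.utils.benchmark as benchmark_utils

# Sketch: `global_setup` holds file-scope C++ (e.g. `#include`s), while
# `setup` and `stmt` still run inside the generated timing function.
timer = benchmark_utils.Timer(
    stmt="torch::Tensor y = x + 1;",
    setup="torch::Tensor x = torch::empty({1});",
    global_setup="#include <ATen/core/ivalue.h>",  # example header; any file-scope code works
    language=benchmark_utils.Language.CPP,
)
print(timer.timeit(100))
```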

Test Plan: Imported from OSS

Reviewed By: Chillee

Differential Revision: D26906684

Pulled By: robieta

fbshipit-source-id: cbdfd79d3b8383100ff2e6857b6f309c387cbe2a
diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py
index e15a0fe..7a653c3 100644
--- a/test/benchmark_utils/test_benchmark_utils.py
+++ b/test/benchmark_utils/test_benchmark_utils.py
@@ -167,8 +167,15 @@
     @unittest.skipIf(IS_SANDCASTLE, "C++ timing is OSS only.")
     def test_cpp_timer(self):
         timer = benchmark_utils.Timer(
-            "torch::Tensor y = x + 1;",
+            """
+                #ifndef TIMER_GLOBAL_CHECK
+                static_assert(false);
+                #endif
+
+                torch::Tensor y = x + 1;
+            """,
             setup="torch::Tensor x = torch::empty({1});",
+            global_setup="#define TIMER_GLOBAL_CHECK",
             timer=timeit.default_timer,
             language=benchmark_utils.Language.CPP,
         )
diff --git a/torch/utils/benchmark/utils/_stubs.py b/torch/utils/benchmark/utils/_stubs.py
index e2ab6ec..0b80a08 100644
--- a/torch/utils/benchmark/utils/_stubs.py
+++ b/torch/utils/benchmark/utils/_stubs.py
@@ -15,7 +15,8 @@
         stmt: str,
         setup: str,
         timer: Callable[[], float],
-        globals: Dict[str, Any]
+        globals: Dict[str, Any],
+        **kwargs: Any,
     ) -> None:
         ...
 
diff --git a/torch/utils/benchmark/utils/common.py b/torch/utils/benchmark/utils/common.py
index 6ad4016..758fd7c 100644
--- a/torch/utils/benchmark/utils/common.py
+++ b/torch/utils/benchmark/utils/common.py
@@ -28,6 +28,7 @@
     """Container for information used to define a Timer. (except globals)"""
     stmt: str
     setup: str
+    global_setup: str = ""
     label: Optional[str] = None
     sub_label: Optional[str] = None
     description: Optional[str] = None
diff --git a/torch/utils/benchmark/utils/cpp_jit.py b/torch/utils/benchmark/utils/cpp_jit.py
index 9cd3e17..9160b8f 100644
--- a/torch/utils/benchmark/utils/cpp_jit.py
+++ b/torch/utils/benchmark/utils/cpp_jit.py
@@ -98,8 +98,16 @@
     return COMPAT_CALLGRIND_BINDINGS
 
 
-def _compile_template(stmt: str, setup: str, src: str, is_standalone: bool) -> Any:
+def _compile_template(
+    *,
+    stmt: str,
+    setup: str,
+    global_setup: str,
+    src: str,
+    is_standalone: bool
+) -> Any:
     for before, after, indentation in (
+        ("// GLOBAL_SETUP_TEMPLATE_LOCATION", global_setup, 0),
         ("// SETUP_TEMPLATE_LOCATION", setup, 4),
         ("// STMT_TEMPLATE_LOCATION", stmt, 8)
     ):
@@ -140,21 +148,21 @@
     )
 
 
-def compile_timeit_template(stmt: str, setup: str) -> TimeitModuleType:
+def compile_timeit_template(*, stmt: str, setup: str, global_setup: str) -> TimeitModuleType:
     template_path: str = os.path.join(SOURCE_ROOT, "timeit_template.cpp")
     with open(template_path, "rt") as f:
         src: str = f.read()
 
-    module = _compile_template(stmt, setup, src, is_standalone=False)
+    module = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=False)
     assert isinstance(module, TimeitModuleType)
     return module
 
 
-def compile_callgrind_template(stmt: str, setup: str) -> str:
+def compile_callgrind_template(*, stmt: str, setup: str, global_setup: str) -> str:
     template_path: str = os.path.join(SOURCE_ROOT, "valgrind_wrapper", "timer_callgrind_template.cpp")
     with open(template_path, "rt") as f:
         src: str = f.read()
 
-    target = _compile_template(stmt, setup, src, is_standalone=True)
+    target = _compile_template(stmt=stmt, setup=setup, global_setup=global_setup, src=src, is_standalone=True)
     assert isinstance(target, str)
     return target
diff --git a/torch/utils/benchmark/utils/timeit_template.cpp b/torch/utils/benchmark/utils/timeit_template.cpp
index 01d62ef..fd46dcb 100644
--- a/torch/utils/benchmark/utils/timeit_template.cpp
+++ b/torch/utils/benchmark/utils/timeit_template.cpp
@@ -1,6 +1,7 @@
 /* C++ template for Timer.timeit
 
 This template will be consumed by `cpp_jit.py`, and will replace:
+    `GLOBAL_SETUP_TEMPLATE_LOCATION`,
     `SETUP_TEMPLATE_LOCATION`
       and
     `STMT_TEMPLATE_LOCATION`
@@ -11,6 +12,8 @@
 #include <pybind11/pybind11.h>
 #include <torch/extension.h>
 
+// Global setup. (e.g. #includes)
+// GLOBAL_SETUP_TEMPLATE_LOCATION
 
 double timeit(int n) {
     // Setup
diff --git a/torch/utils/benchmark/utils/timer.py b/torch/utils/benchmark/utils/timer.py
index f7001b6..371bf18 100644
--- a/torch/utils/benchmark/utils/timer.py
+++ b/torch/utils/benchmark/utils/timer.py
@@ -32,6 +32,7 @@
         self,
         stmt: str,
         setup: str,
+        global_setup: str,
         timer: Callable[[], float],
         globals: Dict[str, Any],
     ) -> None:
@@ -50,13 +51,15 @@
 
         self._stmt: str = textwrap.dedent(stmt)
         self._setup: str = textwrap.dedent(setup)
+        self._global_setup: str = textwrap.dedent(global_setup)
         self._timeit_module: Optional[TimeitModuleType] = None
 
     def timeit(self, number: int) -> float:
         if self._timeit_module is None:
             self._timeit_module = cpp_jit.compile_timeit_template(
-                self._stmt,
-                self._setup,
+                stmt=self._stmt,
+                setup=self._setup,
+                global_setup=self._global_setup,
             )
 
         return self._timeit_module.timeit(number)
@@ -111,6 +114,10 @@
 
         setup: Optional setup code. Used to define variables used in `stmt`
 
+        global_setup: (C++ only)
+            Code which is placed at the top level of the file for things like
+            `#include` statements.
+
         timer:
             Callable which returns the current time. If PyTorch was built
             without CUDA or there is no GPU present, this defaults to
@@ -172,6 +179,7 @@
         self,
         stmt: str = "pass",
         setup: str = "pass",
+        global_setup: str = "",
         timer: Callable[[], float] = timer,
         globals: Optional[Dict[str, Any]] = None,
         label: Optional[str] = None,
@@ -187,16 +195,24 @@
         # We copy `globals` to prevent mutations from leaking.
         # (For instance, `eval` adds the `__builtins__` key)
         self._globals = dict(globals or {})
+
+        timer_kwargs = {}
         if language in (Language.PYTHON, "py", "python"):
             # Include `torch` if not specified as a convenience feature.
             self._globals.setdefault("torch", torch)
             self._language: Language = Language.PYTHON
+            if global_setup:
+                raise ValueError(
+                    f"global_setup is C++ only, got `{global_setup}`. Most "
+                    "likely this code can simply be moved to `setup`."
+                )
 
         elif language in (Language.CPP, "cpp", "c++"):
             assert self._timer_cls is timeit.Timer, "_timer_cls has already been swapped."
             self._timer_cls = CPPTimer
             setup = ("" if setup == "pass" else setup)
             self._language = Language.CPP
+            timer_kwargs["global_setup"] = global_setup
 
         else:
             raise ValueError(f"Invalid language `{language}`.")
@@ -222,10 +238,12 @@
             setup=setup,
             timer=timer,
             globals=valgrind_timer_interface.CopyIfCallgrind.unwrap_all(self._globals),
+            **timer_kwargs,
         )
         self._task_spec = common.TaskSpec(
             stmt=stmt,
             setup=setup,
+            global_setup=global_setup,
             label=label,
             sub_label=sub_label,
             description=description,
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp b/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
index a64484f..e14b347 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_callgrind_template.cpp
@@ -1,6 +1,7 @@
 /* C++ template for Timer.collect_callgrind
 
 This template will be consumed by `cpp_jit.py`, and will replace:
+    `GLOBAL_SETUP_TEMPLATE_LOCATION`,
     `SETUP_TEMPLATE_LOCATION`
       and
     `STMT_TEMPLATE_LOCATION`
@@ -12,10 +13,14 @@
 #include <callgrind.h>
 #include <torch/torch.h>
 
+// Global setup. (e.g. #includes)
+// GLOBAL_SETUP_TEMPLATE_LOCATION
+
 #if defined(NVALGRIND)
 static_assert(false);
 #endif
 
+
 int main(int argc, char* argv[]) {
     // This file should only be called inside of `Timer`, so we can adopt a
     // very simple and rigid argument parsing scheme.
diff --git a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
index 5081325..1f6ce28 100644
--- a/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
+++ b/torch/utils/benchmark/utils/valgrind_wrapper/timer_interface.py
@@ -618,8 +618,9 @@
                 run_loop_cmd = ["python", script_file]
             else:
                 run_loop_exec = cpp_jit.compile_callgrind_template(
-                    task_spec.stmt,
-                    task_spec.setup,
+                    stmt=task_spec.stmt,
+                    setup=task_spec.setup,
+                    global_setup=task_spec.global_setup,
                 )
                 run_loop_cmd = [
                     run_loop_exec,