Revert "[inductor] Faster C++ kernel python bindings (#117500)"
This reverts commit bb0fd1bd3ca145b77159427bc5bacf5f98ec3896.
Reverted https://github.com/pytorch/pytorch/pull/117500 on behalf of https://github.com/PaliC due to breaking internal builds; discussed with the author offline ([comment](https://github.com/pytorch/pytorch/pull/117500#issuecomment-1896516512))
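
For context, a minimal sketch (not part of this patch) of the ctypes-based calling convention this revert returns to: compiled kernels are loaded via `cdll.LoadLibrary`, and call arguments are passed as `c_void_p(tensor.data_ptr())` / `c_long(size)`, matching the restored `wrap_ptr_arg` / `wrap_size_arg` in the diff below. The shared-library path and the `kernel(float*, float*, long)` signature here are assumptions for illustration only.

```python
# Hypothetical illustration of the restored ctypes call path; the .so path
# and the kernel(float* in, float* out, long n) signature are assumptions.
from ctypes import cdll, c_void_p, c_long

import torch

lib = cdll.LoadLibrary("/tmp/torchinductor/kernel.so")  # assumed compiled artifact
x = torch.randn(1024)
out = torch.empty_like(x)
# Pointers are wrapped as c_void_p(data_ptr()) and sizes as c_long(...),
# which is how wrap_ptr_arg / wrap_size_arg emit call arguments after this revert.
lib.kernel(c_void_p(x.data_ptr()), c_void_p(out.data_ptr()), c_long(x.numel()))
```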
diff --git a/benchmarks/dynamo/microbenchmarks/overheads.py b/benchmarks/dynamo/microbenchmarks/overheads.py
deleted file mode 100644
index 93d1ce6..0000000
--- a/benchmarks/dynamo/microbenchmarks/overheads.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import time
-import timeit
-
-import numpy as np
-
-import torch
-
-
-def add1(x):
- return x + 1
-
-
-def bench(name, fn):
- x = torch.randn(1)
- start = time.perf_counter()
- for _ in range(3):
- fn(x)
- end = time.perf_counter()
-
- results = timeit.repeat(lambda: fn(x), number=1000, repeat=100)
- print(f"{name} {np.median(results)*1000:.1f}us (warmup={end-start:.1f}s)")
-
-
-def main():
- bench("eager ", add1)
- bench("compiled", torch.compile(add1))
-
-
-if __name__ == "__main__":
- main()
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 781517c..c7ba622 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -23,7 +23,6 @@
import sys
import sysconfig
import tempfile
-import textwrap
import threading
import warnings
import weakref
@@ -1168,13 +1167,7 @@
def get_shared(shared: bool = True) -> str:
- if not shared:
- return ""
- if platform.system() == "Darwin" and "clang" in cpp_compiler():
- # This causes undefined symbols to behave the same as linux
- return "-shared -fPIC -undefined dynamic_lookup"
- else:
- return "-shared -fPIC"
+ return "-shared -fPIC" if shared else ""
def get_warning_all_flag(warning_all: bool = True) -> str:
@@ -1792,24 +1785,19 @@
class CppCodeCache:
- cache: Dict[str, Union[CDLL, ModuleType]] = {}
+ cache: Dict[str, CDLL] = dict()
clear = staticmethod(cache.clear)
- cpp_compile_command_flags: Dict[str, Any] = {}
@staticmethod
- def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
- return cdll.LoadLibrary(path)
-
- @classmethod
- def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
+ def _load_library(path: str) -> CDLL:
try:
- return cls._load_library_inner(path, key)
- except (ImportError, OSError) as e:
+ return cdll.LoadLibrary(path)
+ except OSError as e:
if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"):
# hacky workaround for fbcode/buck
global _libgomp
_libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1")
- return cls._load_library_inner(path, key)
+ return cdll.LoadLibrary(path)
if "failed to map segment from shared object" in str(e):
raise OSError(
f"{e}. The most common reason this may occur is if the {tempfile.gettempdir()} folder "
@@ -1820,13 +1808,9 @@
raise
@classmethod
- def load(cls, source_code: str) -> Union[CDLL, ModuleType]:
+ def load(cls, source_code: str) -> CDLL:
picked_vec_isa = pick_vec_isa()
- cpp_command = repr(
- cpp_compile_command(
- "i", "o", vec_isa=picked_vec_isa, **cls.cpp_compile_command_flags
- )
- )
+ cpp_command = repr(cpp_compile_command("i", "o", vec_isa=picked_vec_isa))
key, input_path = write(source_code, "cpp", extra=cpp_command)
if key not in cls.cache:
from filelock import FileLock
@@ -1838,101 +1822,16 @@
if not os.path.exists(output_path):
cmd = shlex.split(
cpp_compile_command(
- input=input_path,
- output=output_path,
- vec_isa=picked_vec_isa,
- **cls.cpp_compile_command_flags,
+ input=input_path, output=output_path, vec_isa=picked_vec_isa
)
)
compile_file(input_path, output_path, cmd)
- cls.cache[key] = cls._load_library(output_path, key)
- cls.cache[key].key = key # type: ignore[union-attr]
+ cls.cache[key] = cls._load_library(output_path)
+ cls.cache[key].key = key # type: ignore[attr-defined]
return cls.cache[key]
-class CppPythonBindingsCodeCache(CppCodeCache):
- cache: Dict[str, Union[CDLL, ModuleType]] = {}
- clear = staticmethod(cache.clear)
- cpp_compile_command_flags = {
- "include_pytorch": True,
- "shared": True,
- }
- suffix_template = textwrap.dedent(
- """
- // Python bindings to call kernel():
- #define PY_SSIZE_T_CLEAN
- #include <Python.h>
-
- // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow
- extern "C" void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj);
-
- template <typename T> static inline T parse_arg(PyObject* args, size_t n) {
- static_assert(std::is_pointer<T>::value, "arg type must be pointer or long");
- return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n)));
- }
- template <> inline long parse_arg<long>(PyObject* args, size_t n) {
- auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
- if(result == -1 && PyErr_Occurred())
- [[unlikely]] throw std::runtime_error("expected int arg");
- return result;
- }
-
- static PyObject* kernel_py(PyObject* self, PyObject* args) {
- try {
- if(!PyTuple_CheckExact(args))
- [[unlikely]] throw std::runtime_error("tuple args required");
- if(PyTuple_GET_SIZE(args) != %s)
- [[unlikely]] throw std::runtime_error("requires %s args");
- kernel(%s);
- Py_RETURN_NONE;
- } catch(std::exception const& e) {
- PyErr_SetString(PyExc_RuntimeError, e.what());
- return nullptr;
- }
- }
-
- static PyMethodDef py_methods[] = {
- {"kernel", kernel_py, METH_VARARGS, ""},
- {NULL, NULL, 0, NULL}};
-
- static struct PyModuleDef py_module =
- {PyModuleDef_HEAD_INIT, "kernel", NULL, -1, py_methods};
-
- PyMODINIT_FUNC PyInit_kernel(void) {
- return PyModule_Create(&py_module);
- }
- """
- )
-
- @classmethod
- def _load_library_inner(cls, path: str, key: str) -> ModuleType:
- return importlib.machinery.ExtensionFileLoader(
- f"{key}.kernel", path
- ).load_module() # type: ignore[call-arg]
-
- @classmethod
- def load_pybinding(cls, argtypes: List[str], source_code: str) -> Any:
- """
- Wrap a C++ function in fast Python bindings.
-
- Args:
- argtypes: The types of args to kernel(), e.g. ["float*", "long"]
- source_code: C++ source code containing a kernel() function
-
- Returns:
- A python version of kernel()
- """
- parseargs = ", ".join(
- f"parse_arg<{argtype.replace('const ', '')}>(args, {n})"
- for n, argtype in enumerate(argtypes)
- )
- suffix = cls.suffix_template % (len(argtypes), len(argtypes), parseargs)
- result = cls.load(source_code + suffix)
- assert isinstance(result, ModuleType)
- return result.kernel
-
-
class PyCodeCache:
cache: Dict[str, ModuleType] = dict()
linemaps: Dict[str, List[Tuple[Any, ...]]] = dict()
@@ -2529,13 +2428,6 @@
return self.submit(task)
- def cpp_pybinding(self, argtypes: List[str], source_code: str) -> ModuleType:
- return self.submit(
- functools.partial(
- CppPythonBindingsCodeCache.load_pybinding, argtypes, source_code
- )
- )
-
def cuda(self, source_code, dst_file_ext):
def task():
return CUDACodeCache.load(source_code, dst_file_ext)[0]
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 4bf3c5a..16e08b9 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -667,10 +667,10 @@
)
def wrap_ptr_arg(self, buf, dtype):
- return buf
+ return f"c_void_p({buf}.data_ptr())"
def wrap_size_arg(self, size):
- return str(size)
+ return f"c_long({size})"
def cpp_argdefs(self):
from .cpp import DTYPE_TO_CPP, INDEX_TYPE
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 1dbd8df..07d3bfa 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -3376,6 +3376,7 @@
kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
arg_defs, call_args, arg_types = self.args.cpp_argdefs()
arg_defs = ",\n".ljust(25).join(arg_defs)
+ arg_types = ",".join(arg_types)
code = BracesBuffer()
# TODO: support kernel profile on other platforms
enable_kernel_profile = (
@@ -3402,7 +3403,7 @@
codecache_def = IndentedBuffer()
if not V.graph.cpp_wrapper:
- codecache_def.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
+ codecache_def.writeline("async_compile.cpp('''")
codecache_def.splice(code)
if not V.graph.cpp_wrapper:
codecache_def.writeline("''')")
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 4269b15..45b0e1f 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -650,13 +650,3 @@
return m;
}
-
-extern "C" void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj) {
- if (C10_UNLIKELY(
- obj == nullptr ||
- (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)))) {
- throw std::runtime_error(
- "_torchinductor_pyobject_tensor_data_ptr: non-tensor input");
- }
- return THPVariable_Unpack(obj).data_ptr();
-}