Revert "[inductor] Faster C++ kernel python bindings (#117500)"

This reverts commit bb0fd1bd3ca145b77159427bc5bacf5f98ec3896.

Reverted https://github.com/pytorch/pytorch/pull/117500 on behalf of https://github.com/PaliC due to breaking internal builds; discussed with the author offline ([comment](https://github.com/pytorch/pytorch/pull/117500#issuecomment-1896516512))
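
For context: before #117500, inductor loaded each compiled C++ kernel through ctypes (`cdll.LoadLibrary`) and boxed every argument on every call; #117500 instead linked each kernel with a small generated CPython extension module whose `kernel()` parses the raw `PyObject*` arguments in C++. This revert restores the ctypes path throughout. A minimal sketch of the difference at the call site, assuming a hypothetical kernel with signature `(float*, long)`:

```python
from ctypes import c_void_p, c_long

import torch

x = torch.randn(8)

# Call shape restored by this revert: the generated wrapper boxes every
# argument with ctypes on each call to a CDLL-loaded kernel.
args_ctypes = (c_void_p(x.data_ptr()), c_long(x.numel()))

# Call shape removed by this revert (#117500): the generated extension's
# kernel() takes the tensor and the int directly and reads data_ptr() in C++.
args_direct = (x, x.numel())
```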
diff --git a/benchmarks/dynamo/microbenchmarks/overheads.py b/benchmarks/dynamo/microbenchmarks/overheads.py
deleted file mode 100644
index 93d1ce6..0000000
--- a/benchmarks/dynamo/microbenchmarks/overheads.py
+++ /dev/null
@@ -1,30 +0,0 @@
-import time
-import timeit
-
-import numpy as np
-
-import torch
-
-
-def add1(x):
-    return x + 1
-
-
-def bench(name, fn):
-    x = torch.randn(1)
-    start = time.perf_counter()
-    for _ in range(3):
-        fn(x)
-    end = time.perf_counter()
-
-    results = timeit.repeat(lambda: fn(x), number=1000, repeat=100)
-    print(f"{name} {np.median(results)*1000:.1f}us (warmup={end-start:.1f}s)")
-
-
-def main():
-    bench("eager   ", add1)
-    bench("compiled", torch.compile(add1))
-
-
-if __name__ == "__main__":
-    main()
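
This microbenchmark is deleted along with the change it was added to measure: per-call dispatch overhead of a trivial compiled function versus eager. The same measurement can still be taken directly, along the lines of:

```python
import timeit

import numpy as np
import torch

fn = torch.compile(lambda x: x + 1)
x = torch.randn(1)
fn(x)  # first call pays the compilation cost

# Median over repeats, matching the deleted script's methodology.
results = timeit.repeat(lambda: fn(x), number=1000, repeat=100)
print(f"compiled {np.median(results) * 1000:.1f}us")
```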
diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py
index 781517c..c7ba622 100644
--- a/torch/_inductor/codecache.py
+++ b/torch/_inductor/codecache.py
@@ -23,7 +23,6 @@
 import sys
 import sysconfig
 import tempfile
-import textwrap
 import threading
 import warnings
 import weakref
@@ -1168,13 +1167,7 @@
 
 
 def get_shared(shared: bool = True) -> str:
-    if not shared:
-        return ""
-    if platform.system() == "Darwin" and "clang" in cpp_compiler():
-        # This causes undefined symbols to behave the same as linux
-        return "-shared -fPIC -undefined dynamic_lookup"
-    else:
-        return "-shared -fPIC"
+    return "-shared -fPIC" if shared else ""
 
 
 def get_warning_all_flag(warning_all: bool = True) -> str:
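
The macOS special case being removed existed because the pre-revert extension modules left `_torchinductor_pyobject_tensor_data_ptr` (exported from `guards.cpp`, see the last hunk below) unresolved at link time; Apple's linker rejects undefined symbols in a shared library unless told to defer them, which is what `-undefined dynamic_lookup` does (the removed comment notes it makes undefined symbols "behave the same as linux"). With the bindings gone, plain `-shared -fPIC` suffices everywhere. The removed branch, restated with that rationale (illustrative only; the tree additionally checked for clang via `cpp_compiler()`):

```python
import platform

def shared_flags(shared: bool = True) -> str:
    # Restatement of the pre-revert logic for reference, not part of the tree.
    if not shared:
        return ""
    if platform.system() == "Darwin":
        # Defer resolution of _torchinductor_pyobject_tensor_data_ptr until
        # the extension is loaded into the torch process.
        return "-shared -fPIC -undefined dynamic_lookup"
    return "-shared -fPIC"
```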
@@ -1792,24 +1785,19 @@
 
 
 class CppCodeCache:
-    cache: Dict[str, Union[CDLL, ModuleType]] = {}
+    cache: Dict[str, CDLL] = dict()
     clear = staticmethod(cache.clear)
-    cpp_compile_command_flags: Dict[str, Any] = {}
 
     @staticmethod
-    def _load_library_inner(path: str, key: str) -> Union[CDLL, ModuleType]:
-        return cdll.LoadLibrary(path)
-
-    @classmethod
-    def _load_library(cls, path: str, key: str) -> Union[CDLL, ModuleType]:
+    def _load_library(path: str) -> CDLL:
         try:
-            return cls._load_library_inner(path, key)
-        except (ImportError, OSError) as e:
+            return cdll.LoadLibrary(path)
+        except OSError as e:
             if "gomp" in str(e) and os.path.exists("/usr/lib64/libgomp.so.1"):
                 # hacky workaround for fbcode/buck
                 global _libgomp
                 _libgomp = cdll.LoadLibrary("/usr/lib64/libgomp.so.1")
-                return cls._load_library_inner(path, key)
+                return cdll.LoadLibrary(path)
             if "failed to map segment from shared object" in str(e):
                 raise OSError(
                     f"{e}.  The most common reason this may occur is if the {tempfile.gettempdir()} folder "
@@ -1820,13 +1808,9 @@
             raise
 
     @classmethod
-    def load(cls, source_code: str) -> Union[CDLL, ModuleType]:
+    def load(cls, source_code: str) -> CDLL:
         picked_vec_isa = pick_vec_isa()
-        cpp_command = repr(
-            cpp_compile_command(
-                "i", "o", vec_isa=picked_vec_isa, **cls.cpp_compile_command_flags
-            )
-        )
+        cpp_command = repr(cpp_compile_command("i", "o", vec_isa=picked_vec_isa))
         key, input_path = write(source_code, "cpp", extra=cpp_command)
         if key not in cls.cache:
             from filelock import FileLock
@@ -1838,101 +1822,16 @@
                 if not os.path.exists(output_path):
                     cmd = shlex.split(
                         cpp_compile_command(
-                            input=input_path,
-                            output=output_path,
-                            vec_isa=picked_vec_isa,
-                            **cls.cpp_compile_command_flags,
+                            input=input_path, output=output_path, vec_isa=picked_vec_isa
                         )
                     )
                     compile_file(input_path, output_path, cmd)
-                cls.cache[key] = cls._load_library(output_path, key)
-                cls.cache[key].key = key  # type: ignore[union-attr]
+                cls.cache[key] = cls._load_library(output_path)
+                cls.cache[key].key = key  # type: ignore[attr-defined]
 
         return cls.cache[key]
 
 
-class CppPythonBindingsCodeCache(CppCodeCache):
-    cache: Dict[str, Union[CDLL, ModuleType]] = {}
-    clear = staticmethod(cache.clear)
-    cpp_compile_command_flags = {
-        "include_pytorch": True,
-        "shared": True,
-    }
-    suffix_template = textwrap.dedent(
-        """
-        // Python bindings to call kernel():
-        #define PY_SSIZE_T_CLEAN
-        #include <Python.h>
-
-        // This is defined in guards.cpp so we don't need to import PyTorch headers that are slooow
-        extern "C" void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj);
-
-        template <typename T> static inline T parse_arg(PyObject* args, size_t n) {
-            static_assert(std::is_pointer<T>::value, "arg type must be pointer or long");
-            return static_cast<T>(_torchinductor_pyobject_tensor_data_ptr(PyTuple_GET_ITEM(args, n)));
-        }
-        template <> inline long parse_arg<long>(PyObject* args, size_t n) {
-            auto result = PyLong_AsSsize_t(PyTuple_GET_ITEM(args, n));
-            if(result == -1 && PyErr_Occurred())
-                [[unlikely]] throw std::runtime_error("expected int arg");
-            return result;
-        }
-
-        static PyObject* kernel_py(PyObject* self, PyObject* args) {
-            try {
-                if(!PyTuple_CheckExact(args))
-                    [[unlikely]] throw std::runtime_error("tuple args required");
-                if(PyTuple_GET_SIZE(args) != %s)
-                    [[unlikely]] throw std::runtime_error("requires %s args");
-                kernel(%s);
-                Py_RETURN_NONE;
-            } catch(std::exception const& e) {
-                PyErr_SetString(PyExc_RuntimeError, e.what());
-                return nullptr;
-            }
-        }
-
-        static PyMethodDef py_methods[] = {
-            {"kernel", kernel_py, METH_VARARGS, ""},
-            {NULL, NULL, 0, NULL}};
-
-        static struct PyModuleDef py_module =
-            {PyModuleDef_HEAD_INIT, "kernel", NULL, -1, py_methods};
-
-        PyMODINIT_FUNC PyInit_kernel(void) {
-            return PyModule_Create(&py_module);
-        }
-        """
-    )
-
-    @classmethod
-    def _load_library_inner(cls, path: str, key: str) -> ModuleType:
-        return importlib.machinery.ExtensionFileLoader(
-            f"{key}.kernel", path
-        ).load_module()  # type: ignore[call-arg]
-
-    @classmethod
-    def load_pybinding(cls, argtypes: List[str], source_code: str) -> Any:
-        """
-        Wrap a C++ function in fast Python bindings.
-
-        Args:
-            argtypes: The types of args to kernel(), e.g. ["float*", "long"]
-            source_code: C++ source code containing a kernel() function
-
-        Returns:
-            A python version of kernel()
-        """
-        parseargs = ", ".join(
-            f"parse_arg<{argtype.replace('const ', '')}>(args, {n})"
-            for n, argtype in enumerate(argtypes)
-        )
-        suffix = cls.suffix_template % (len(argtypes), len(argtypes), parseargs)
-        result = cls.load(source_code + suffix)
-        assert isinstance(result, ModuleType)
-        return result.kernel
-
-
 class PyCodeCache:
     cache: Dict[str, ModuleType] = dict()
     linemaps: Dict[str, List[Tuple[Any, ...]]] = dict()
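
For reference, what the removed `load_pybinding` generated: it filled `suffix_template` with the argument count and one `parse_arg` expression per argument, then compiled kernel source plus suffix as a single module. With `argtypes=["float*", "long"]`:

```python
# Reproduces the parseargs string the removed load_pybinding built.
argtypes = ["float*", "long"]
parseargs = ", ".join(
    f"parse_arg<{t.replace('const ', '')}>(args, {n})"
    for n, t in enumerate(argtypes)
)
assert parseargs == "parse_arg<float*>(args, 0), parse_arg<long>(args, 1)"
# suffix_template % (2, 2, parseargs) then expanded kernel_py's body to:
#     if(PyTuple_GET_SIZE(args) != 2) ... throw ...
#     kernel(parse_arg<float*>(args, 0), parse_arg<long>(args, 1));
```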
@@ -2529,13 +2428,6 @@
 
         return self.submit(task)
 
-    def cpp_pybinding(self, argtypes: List[str], source_code: str) -> ModuleType:
-        return self.submit(
-            functools.partial(
-                CppPythonBindingsCodeCache.load_pybinding, argtypes, source_code
-            )
-        )
-
     def cuda(self, source_code, dst_file_ext):
         def task():
             return CUDACodeCache.load(source_code, dst_file_ext)[0]
diff --git a/torch/_inductor/codegen/common.py b/torch/_inductor/codegen/common.py
index 4bf3c5a..16e08b9 100644
--- a/torch/_inductor/codegen/common.py
+++ b/torch/_inductor/codegen/common.py
@@ -667,10 +667,10 @@
         )
 
     def wrap_ptr_arg(self, buf, dtype):
-        return buf
+        return f"c_void_p({buf}.data_ptr())"
 
     def wrap_size_arg(self, size):
-        return str(size)
+        return f"c_long({size})"
 
     def cpp_argdefs(self):
         from .cpp import DTYPE_TO_CPP, INDEX_TYPE
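
These two hooks control how the generated Python wrapper renders kernel arguments. With the revert, every call site boxes through ctypes again, e.g. `kernel(c_void_p(buf0.data_ptr()), c_long(s0))`, where the reverted bindings passed `(buf0, s0)` straight through. The per-call boxing this reinstates can be timed in isolation (names and sizes illustrative):

```python
import timeit
from ctypes import c_void_p, c_long

import torch

buf0 = torch.randn(1024)
s0 = buf0.numel()

# The argument conversion the generated wrapper now performs on every call.
box = lambda: (c_void_p(buf0.data_ptr()), c_long(s0))
print(f"{timeit.timeit(box, number=100_000):.3f}s per 100k argument boxings")
```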
diff --git a/torch/_inductor/codegen/cpp.py b/torch/_inductor/codegen/cpp.py
index 1dbd8df..07d3bfa 100644
--- a/torch/_inductor/codegen/cpp.py
+++ b/torch/_inductor/codegen/cpp.py
@@ -3376,6 +3376,7 @@
         kernel_name = "_".join(["cpp", fused_name, wrapper.next_kernel_suffix()])
         arg_defs, call_args, arg_types = self.args.cpp_argdefs()
         arg_defs = ",\n".ljust(25).join(arg_defs)
+        arg_types = ",".join(arg_types)
         code = BracesBuffer()
         # TODO: support kernel profile on other platforms
         enable_kernel_profile = (
@@ -3402,7 +3403,7 @@
 
         codecache_def = IndentedBuffer()
         if not V.graph.cpp_wrapper:
-            codecache_def.writeline(f"async_compile.cpp_pybinding({arg_types!r}, '''")
+            codecache_def.writeline("async_compile.cpp('''")
         codecache_def.splice(code)
         if not V.graph.cpp_wrapper:
             codecache_def.writeline("''')")
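
Net effect on the generated wrapper: kernels register through the plain `async_compile.cpp` path again, and `arg_types` is joined back into a single string since the binding path that consumed the list is gone. An illustrative before/after of the emitted registration (kernel name and signature hypothetical):

```python
arg_types = ["float*", "long"]  # hypothetical cpp_argdefs() output
# Before this revert (#117500), the wrapper told the binding how to parse
# each argument:
pre = f"cpp_fused_add_0 = async_compile.cpp_pybinding({arg_types!r}, '''<kernel source>''')"
# After this revert:
post = "cpp_fused_add_0 = async_compile.cpp('''<kernel source>''')"
print(pre, post, sep="\n")
```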
diff --git a/torch/csrc/dynamo/guards.cpp b/torch/csrc/dynamo/guards.cpp
index 4269b15..45b0e1f 100644
--- a/torch/csrc/dynamo/guards.cpp
+++ b/torch/csrc/dynamo/guards.cpp
@@ -650,13 +650,3 @@
 
   return m;
 }
-
-extern "C" void* _torchinductor_pyobject_tensor_data_ptr(PyObject* obj) {
-  if (C10_UNLIKELY(
-          obj == nullptr ||
-          (!THPVariable_CheckExact(obj) && !THPVariable_Check(obj)))) {
-    throw std::runtime_error(
-        "_torchinductor_pyobject_tensor_data_ptr: non-tensor input");
-  }
-  return THPVariable_Unpack(obj).data_ptr();
-}
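
Finally, the exported helper goes away: it let the generated bindings fetch a tensor's data pointer without compiling against the full PyTorch headers (keeping per-kernel builds fast, per the removed "slooow" comment in the suffix template). With no bindings generated, nothing references the symbol. Its Python-visible equivalent, for comparison:

```python
import torch

x = torch.randn(4)
# The removed C function returned, for a tensor PyObject*, the address that
# Python exposes as:
ptr = x.data_ptr()
assert isinstance(ptr, int)
# Non-tensor inputs raised "non-tensor input", mirroring the removed
# THPVariable_Check guard.
```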