Move JIT passes to a separate directory
diff --git a/setup.py b/setup.py
index 9848d08..29778fc 100644
--- a/setup.py
+++ b/setup.py
@@ -374,13 +374,13 @@
"torch/csrc/jit/init.cpp",
"torch/csrc/jit/ir.cpp",
"torch/csrc/jit/python_ir.cpp",
- "torch/csrc/jit/graph_fuser.cpp",
- "torch/csrc/jit/init_pass.cpp",
- "torch/csrc/jit/dead_code_elimination.cpp",
"torch/csrc/jit/test_jit.cpp",
"torch/csrc/jit/tracer.cpp",
"torch/csrc/jit/python_tracer.cpp",
"torch/csrc/jit/interned_strings.cpp",
+ "torch/csrc/jit/passes/graph_fuser.cpp",
+ "torch/csrc/jit/passes/init_pass.cpp",
+ "torch/csrc/jit/passes/dead_code_elimination.cpp",
"torch/csrc/autograd/init.cpp",
"torch/csrc/autograd/engine.cpp",
"torch/csrc/autograd/function.cpp",
diff --git a/torch/csrc/Module.cpp b/torch/csrc/Module.cpp
index 43d1511..7027205 100644
--- a/torch/csrc/Module.cpp
+++ b/torch/csrc/Module.cpp
@@ -558,7 +558,6 @@
extern PyObject * THCSPModule_initExtension(PyObject *self);
#endif
-extern PyObject * THPJIT_initExtension(PyObject *self);
static PyMethodDef TorchMethods[] = {
{"_initExtension", (PyCFunction)THPModule_initExtension, METH_O, NULL},
@@ -799,7 +798,6 @@
#endif
THPUtils_addPyMethodDefs(methods, TorchMethods);
- THPUtils_addPyMethodDefs(methods, THPJIT_methods());
#ifdef WITH_CUDNN
THPUtils_addPyMethodDefs(methods, THCUDNN_methods());
#endif
@@ -826,9 +824,8 @@
ASSERT_TRUE(THPVariable_initModule(module));
ASSERT_TRUE(THPFunction_initModule(module));
ASSERT_TRUE(THPEngine_initModule(module));
- torch::jit::initPythonIRBindings(module);
- torch::jit::initPythonTracerBindings(module);
torch::autograd::initAutogradClosureBindings(module);
+ torch::jit::initJITBindings(module);
ASSERT_TRUE(THPDoubleStorage_init(module));
ASSERT_TRUE(THPFloatStorage_init(module));
ASSERT_TRUE(THPHalfStorage_init(module));
diff --git a/torch/csrc/jit/init.cpp b/torch/csrc/jit/init.cpp
index 6651117..8130369 100644
--- a/torch/csrc/jit/init.cpp
+++ b/torch/csrc/jit/init.cpp
@@ -1,69 +1,49 @@
-#include <Python.h>
-#include <pybind11/pybind11.h>
+#include "torch/csrc/utils/pybind.h"
-namespace py = pybind11;
-
-#include "THP.h"
-#include "torch/csrc/jit/ir.h"
-#include "torch/csrc/jit/graph_fuser.h"
-#include "torch/csrc/jit/init_pass.h"
-#include "torch/csrc/jit/dead_code_elimination.h"
#include "torch/csrc/jit/python_tracer.h"
-#include "torch/csrc/utils/python_strings.h"
-#include "torch/csrc/DynamicTypes.h"
+#include "torch/csrc/jit/python_ir.h"
+#include "torch/csrc/jit/passes/graph_fuser.h"
+#include "torch/csrc/jit/passes/init_pass.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
#include "torch/csrc/onnx/export.h"
-PyObject * THPJIT_initExtension(PyObject *_unused)
-{
+
+
+namespace torch { namespace jit {
+
+namespace {
+
+bool loadPythonClasses() {
// Leaving this code here, because it will likely be useful at some point
//PyObject *jit_module = PyImport_ImportModule("torch.jit");
//THPUtils_assert(jit_module, "class loader couldn't access "
//"torch.jit module");
//PyObject *jit_dict = PyModule_GetDict(jit_module);
- Py_RETURN_TRUE;
+ return true;
}
-// stub to run all C++ only tests for the JIT
-// the stuff in test_jit.cpp is kept separate from the rest of PyTorch
-// so we can build and iterate on it faster.
-// from test_jit.cpp
-namespace torch { namespace jit { extern void runJITCPPTests(); } };
-
-namespace {
-
-using namespace torch::jit;
-
-using pass_type = void (std::shared_ptr<Graph>&);
-
-template<pass_type pass>
-PyObject * wrap_pass(PyObject *_unused, PyObject *py_state) {
- HANDLE_TH_ERRORS
- auto trace = py::handle(py_state).cast<tracer::TracingState*>();
- pass(trace->graph);
- Py_RETURN_NONE;
- END_HANDLE_TH_ERRORS
+template<void (*F)(std::shared_ptr<Graph>& graph)>
+void graph_pass(const std::shared_ptr<tracer::TracingState>& state) {
+ return F(state->graph);
}
-PyObject * run_cpp_tests(PyObject *_unused, PyObject *_unused2) {
- HANDLE_TH_ERRORS
- runJITCPPTests();
- Py_RETURN_NONE;
- END_HANDLE_TH_ERRORS
-}
-
-struct PyMethodDef _THPJIT_methods[] = {
- {"_jit_init", (PyCFunction)THPJIT_initExtension, METH_NOARGS, NULL},
- {"_jit_pass_init", (PyCFunction)wrap_pass<MatchJITOps>, METH_O, "init"},
- {"_jit_pass_fuse", (PyCFunction)wrap_pass<FuseGraph>, METH_O, "fuse"},
- {"_jit_pass_dce", (PyCFunction)wrap_pass<EliminateDeadCode>, METH_O, "dce"},
- {"_jit_pass_lint", (PyCFunction)wrap_pass<LintGraph>, METH_O, "lint"},
- {"_jit_run_cpp_tests",(PyCFunction)run_cpp_tests, METH_NOARGS, NULL},
- {NULL}
-};
-
} // anonymous namespace
-PyMethodDef* THPJIT_methods() {
- return _THPJIT_methods;
+extern void runJITCPPTests();
+
+void initJITBindings(PyObject *module) {
+ auto m = py::handle(module).cast<py::module>();
+
+ m.def("_jit_init", loadPythonClasses)
+ .def("_jit_pass_init", graph_pass<MatchJITOps>)
+ .def("_jit_pass_fuse", graph_pass<FuseGraph>)
+ .def("_jit_pass_dce", graph_pass<EliminateDeadCode>)
+ .def("_jit_pass_lint", graph_pass<LintGraph>)
+ .def("_jit_run_cpp_tests", runJITCPPTests);
+
+ initPythonIRBindings(module);
+ initPythonTracerBindings(module);
}
+
+}}
diff --git a/torch/csrc/jit/init.h b/torch/csrc/jit/init.h
index 866df08..fbc902e 100644
--- a/torch/csrc/jit/init.h
+++ b/torch/csrc/jit/init.h
@@ -1,3 +1,7 @@
#pragma once
-PyMethodDef* THPJIT_methods();
+namespace torch { namespace jit {
+
+void initJITBindings(PyObject *module);
+
+}}
diff --git a/torch/csrc/jit/dead_code_elimination.cpp b/torch/csrc/jit/passes/dead_code_elimination.cpp
similarity index 92%
rename from torch/csrc/jit/dead_code_elimination.cpp
rename to torch/csrc/jit/passes/dead_code_elimination.cpp
index 60e6f1c..640d284 100644
--- a/torch/csrc/jit/dead_code_elimination.cpp
+++ b/torch/csrc/jit/passes/dead_code_elimination.cpp
@@ -1,4 +1,4 @@
-#include "torch/csrc/jit/dead_code_elimination.h"
+#include "torch/csrc/jit/passes/dead_code_elimination.h"
namespace torch { namespace jit {
diff --git a/torch/csrc/jit/dead_code_elimination.h b/torch/csrc/jit/passes/dead_code_elimination.h
similarity index 100%
rename from torch/csrc/jit/dead_code_elimination.h
rename to torch/csrc/jit/passes/dead_code_elimination.h
diff --git a/torch/csrc/jit/graph_fuser.cpp b/torch/csrc/jit/passes/graph_fuser.cpp
similarity index 99%
rename from torch/csrc/jit/graph_fuser.cpp
rename to torch/csrc/jit/passes/graph_fuser.cpp
index 2ab7303..b7b1d41 100644
--- a/torch/csrc/jit/graph_fuser.cpp
+++ b/torch/csrc/jit/passes/graph_fuser.cpp
@@ -1,4 +1,4 @@
-#include "torch/csrc/jit/graph_fuser.h"
+#include "torch/csrc/jit/passes/graph_fuser.h"
#include <unordered_map>
namespace torch { namespace jit {
diff --git a/torch/csrc/jit/graph_fuser.h b/torch/csrc/jit/passes/graph_fuser.h
similarity index 100%
rename from torch/csrc/jit/graph_fuser.h
rename to torch/csrc/jit/passes/graph_fuser.h
diff --git a/torch/csrc/jit/init_pass.cpp b/torch/csrc/jit/passes/init_pass.cpp
similarity index 98%
rename from torch/csrc/jit/init_pass.cpp
rename to torch/csrc/jit/passes/init_pass.cpp
index 03463af..7fad109 100644
--- a/torch/csrc/jit/init_pass.cpp
+++ b/torch/csrc/jit/passes/init_pass.cpp
@@ -1,4 +1,4 @@
-#include "torch/csrc/jit/init_pass.h"
+#include "torch/csrc/jit/passes/init_pass.h"
#include <unordered_map>
namespace torch { namespace jit {
diff --git a/torch/csrc/jit/init_pass.h b/torch/csrc/jit/passes/init_pass.h
similarity index 100%
rename from torch/csrc/jit/init_pass.h
rename to torch/csrc/jit/passes/init_pass.h
diff --git a/torch/csrc/jit/python_ir.h b/torch/csrc/jit/python_ir.h
index a980889..f9c890e 100644
--- a/torch/csrc/jit/python_ir.h
+++ b/torch/csrc/jit/python_ir.h
@@ -1,5 +1,9 @@
#pragma once
+#include "torch/csrc/jit/ir.h"
+
namespace torch { namespace jit {
+
void initPythonIRBindings(PyObject* module);
+
}}
diff --git a/torch/csrc/jit/tracer.h b/torch/csrc/jit/tracer.h
index c32249c..8ef797d 100644
--- a/torch/csrc/jit/tracer.h
+++ b/torch/csrc/jit/tracer.h
@@ -6,7 +6,6 @@
#include "torch/csrc/utils/functional.h"
#include "torch/csrc/autograd/function_hook.h"
#include "torch/csrc/autograd/variable.h"
-#include "torch/csrc/jit/init_pass.h"
#include <memory>
#include <mutex>
diff --git a/torch/csrc/onnx/export.cpp b/torch/csrc/onnx/export.cpp
index fb0978d..7c7a588 100644
--- a/torch/csrc/onnx/export.cpp
+++ b/torch/csrc/onnx/export.cpp
@@ -6,7 +6,6 @@
#include "torch/csrc/onnx.h"
#include "torch/csrc/autograd/functions/convolution.h"
-#include "torch/csrc/jit/dead_code_elimination.h"
#include "torch/csrc/utils/functional.h"
#include <ATen/ATen.h>