[TensorExpr] Fix some TE python bindings. (#68232)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/68232
Differential Revision:
D32380676
D32380676
Test Plan: Imported from OSS
Reviewed By: saketh-are
Pulled By: ZolotukhinM
fbshipit-source-id: 9287a2c765a53b45ac04d625cc010f5384a8bddf
diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
index 98bdf8d..62e6f69 100644
--- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
+++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
@@ -453,6 +453,7 @@
py::class_<LoopNest>(te, "LoopNest")
.def(py::init<const std::vector<Tensor>&>())
+ .def(py::init<const std::vector<Tensor>&, const std::vector<Tensor>&>())
.def(py::init([](StmtPtr s, const std::vector<BufHandle>& bufs) {
std::unordered_set<BufPtr> buf_nodes;
for (auto& buf : bufs) {
@@ -623,7 +624,7 @@
return LoopNest::compressBuffer(buf.node(), stmt);
},
py::return_value_policy::reference)
- .def(
+ .def_static(
"cache_accesses",
[](const BufHandle& producer,
const std::string& name,
@@ -633,7 +634,7 @@
return std::make_pair(BufHandle(ret.first), ret.second);
},
py::return_value_policy::reference)
- .def(
+ .def_static(
"compute_at",
[](StmtPtr s, ForPtr at) { LoopNest::computeAt(s, at); })
.def(