[NNC] Added reductions to NNC python bindings. (#52492)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52492
Reviewed By: bdhirsh
Differential Revision: D26575506
Pulled By: Chillee
fbshipit-source-id: 9a070f591a9709dab55dfff849184b1bcffc4fa5
diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
index 32d9a44..3bc1705 100644
--- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
+++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
@@ -1,3 +1,4 @@
+#include <pybind11/functional.h>
#include <pybind11/operators.h>
#include <torch/csrc/jit/python/pybind_utils.h>
#include <torch/csrc/jit/tensorexpr/codegen.h>
@@ -223,8 +224,11 @@
AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, EXPRHANDLE_CTOR)
#undef EXPRHANDLE_CTOR
- py::class_<tensorexpr::VarHandle, tensorexpr::ExprHandle>(te, "VarHandle");
- py::class_<tensorexpr::BufHandle, tensorexpr::ExprHandle>(te, "BufHandle");
+ py::class_<tensorexpr::VarHandle, tensorexpr::ExprHandle>(te, "VarHandle")
+ .def(py::init<const std::string&, tensorexpr::Dtype>());
+ py::class_<tensorexpr::BufHandle, tensorexpr::ExprHandle>( // NOLINT
+ te,
+ "BufHandle");
py::class_<tensorexpr::Placeholder>(te, "Placeholder")
.def(py::init<
@@ -296,16 +300,35 @@
}
},
py::return_value_policy::reference);
- py::class_<tensorexpr::Reducer>(te, "Reducer");
+ py::class_<tensorexpr::Reducer>(te, "Reducer")
+ .def(py::init<
+ tensorexpr::ExprHandle,
+ std::function<tensorexpr::ExprHandle(
+ tensorexpr::ExprHandle, tensorexpr::ExprHandle)>>());
+ py::class_<tensorexpr::Sum, tensorexpr::Reducer>(te, "Sum").def(py::init<>());
+ py::class_<tensorexpr::Maximum, tensorexpr::Reducer>(te, "Maximum")
+ .def(py::init<tensorexpr::Dtype>());
te.def(
- "SumReduce",
+ "Reduce",
[](const std::string& func_name,
const std::vector<tensorexpr::DimArg>& dim_args,
+ const tensorexpr::Reducer& reducer,
tensorexpr::Tensor* buffer,
const std::vector<tensorexpr::DimArg>& reduce_args) {
return tensorexpr::Reduce(
- func_name, dim_args, tensorexpr::Sum(), buffer, reduce_args);
+ func_name, dim_args, reducer, buffer, reduce_args);
+ },
+ py::return_value_policy::reference);
+ te.def(
+ "Reduce",
+ [](const std::string& func_name,
+ const std::vector<tensorexpr::DimArg>& dim_args,
+ const tensorexpr::Reducer& reducer,
+ const tensorexpr::Placeholder& buffer,
+ const std::vector<tensorexpr::DimArg>& reduce_args) {
+ return tensorexpr::Reduce(
+ func_name, dim_args, reducer, buffer, reduce_args);
},
py::return_value_policy::reference);