[NNC] Added reductions to NNC python bindings. (#52492)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/52492

Reviewed By: bdhirsh

Differential Revision: D26575506

Pulled By: Chillee

fbshipit-source-id: 9a070f591a9709dab55dfff849184b1bcffc4fa5
diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
index 32d9a44..3bc1705 100644
--- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
+++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
@@ -1,3 +1,4 @@
+#include <pybind11/functional.h>
 #include <pybind11/operators.h>
 #include <torch/csrc/jit/python/pybind_utils.h>
 #include <torch/csrc/jit/tensorexpr/codegen.h>
@@ -223,8 +224,11 @@
   AT_FORALL_SCALAR_TYPES_AND2(Bool, Half, EXPRHANDLE_CTOR)
 #undef EXPRHANDLE_CTOR
 
-  py::class_<tensorexpr::VarHandle, tensorexpr::ExprHandle>(te, "VarHandle");
-  py::class_<tensorexpr::BufHandle, tensorexpr::ExprHandle>(te, "BufHandle");
+  py::class_<tensorexpr::VarHandle, tensorexpr::ExprHandle>(te, "VarHandle")
+      .def(py::init<const std::string&, tensorexpr::Dtype>());
+  py::class_<tensorexpr::BufHandle, tensorexpr::ExprHandle>( // NOLINT
+      te,
+      "BufHandle");
 
   py::class_<tensorexpr::Placeholder>(te, "Placeholder")
       .def(py::init<
@@ -296,16 +300,35 @@
         }
       },
       py::return_value_policy::reference);
-  py::class_<tensorexpr::Reducer>(te, "Reducer");
+  py::class_<tensorexpr::Reducer>(te, "Reducer")
+      .def(py::init<
+           tensorexpr::ExprHandle,
+           std::function<tensorexpr::ExprHandle(
+               tensorexpr::ExprHandle, tensorexpr::ExprHandle)>>());
 
+  py::class_<tensorexpr::Sum, tensorexpr::Reducer>(te, "Sum").def(py::init<>());
+  py::class_<tensorexpr::Maximum, tensorexpr::Reducer>(te, "Maximum")
+      .def(py::init<tensorexpr::Dtype>());
   te.def(
-      "SumReduce",
+      "Reduce",
       [](const std::string& func_name,
          const std::vector<tensorexpr::DimArg>& dim_args,
+         const tensorexpr::Reducer& reducer,
          tensorexpr::Tensor* buffer,
          const std::vector<tensorexpr::DimArg>& reduce_args) {
         return tensorexpr::Reduce(
-            func_name, dim_args, tensorexpr::Sum(), buffer, reduce_args);
+            func_name, dim_args, reducer, buffer, reduce_args);
+      },
+      py::return_value_policy::reference);
+  te.def(
+      "Reduce",
+      [](const std::string& func_name,
+         const std::vector<tensorexpr::DimArg>& dim_args,
+         const tensorexpr::Reducer& reducer,
+         const tensorexpr::Placeholder& buffer,
+         const std::vector<tensorexpr::DimArg>& reduce_args) {
+        return tensorexpr::Reduce(
+            func_name, dim_args, reducer, buffer, reduce_args);
       },
       py::return_value_policy::reference);