[TensorExpr] Switch from `ExprPtr` to `ExprHandle` in Compute impl. (#72389)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/72389

This is an NFC change that just prepares the code for the upcoming
deletion of the `DimArg` class. It makes the `Compute` and `Reduce`
APIs use `ExprHandle` everywhere.
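
For illustration, a minimal sketch (not part of this patch) of what call
sites look like with the handle-based APIs. The `Sum()` call mirrors the
test change in test_memdependency.cpp below; the buffer and variable names
here are made up for the example:

  #include <torch/csrc/jit/tensorexpr/expr.h>
  #include <torch/csrc/jit/tensorexpr/reduction.h>

  using namespace torch::jit::tensorexpr;

  void example() {
    // Buf::make now returns a BufHandle rather than a plain ExprHandle.
    BufHandle a = Buf::make("A", {ExprHandle(10)}, kInt);
    VarHandle x("x", kInt);

    // Before: callers had to unwrap handles into raw nodes and rewrap the
    // result:
    //   ExprHandle(Sum()(a.node(), ExprHandle(1), {x.node()}, {x.node()}))
    // After: Reducer::operator() accepts handles and returns an ExprHandle.
    ExprHandle reduce = Sum()(a, 1, {x}, {x});
    (void)reduce;
  }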

There should be no observable behavior change from this PR.

Test Plan: Imported from OSS

Reviewed By: navahgar

Differential Revision: D34030295

Pulled By: ZolotukhinM

fbshipit-source-id: 3fd035b6a6bd0a07ccfa92e118819478ae85412a
(cherry picked from commit 1b0a4b6fac54aa4d4735df435d345a30ba0d8a53)
diff --git a/test/cpp/tensorexpr/test_memdependency.cpp b/test/cpp/tensorexpr/test_memdependency.cpp
index 7019353..31c069c 100644
--- a/test/cpp/tensorexpr/test_memdependency.cpp
+++ b/test/cpp/tensorexpr/test_memdependency.cpp
@@ -543,8 +543,7 @@
    */
 
   StorePtr aInit = Store::make(a, {0}, 0);
-  ExprHandle reduce =
-      ExprHandle(Sum()(a.node(), ExprHandle(1), {x.node()}, {x.node()}));
+  ExprHandle reduce = Sum()(a, 1, {x}, {x});
   StorePtr aReduce = Store::make(a, {0}, reduce);
   StmtPtr loop = For::make(x, 0, 10, aReduce);
   StorePtr bStore = Store::make(b, {0}, Load::make(a, {0}));
diff --git a/torch/csrc/jit/tensorexpr/expr.cpp b/torch/csrc/jit/tensorexpr/expr.cpp
index e17727c..fb47c58 100644
--- a/torch/csrc/jit/tensorexpr/expr.cpp
+++ b/torch/csrc/jit/tensorexpr/expr.cpp
@@ -425,16 +425,30 @@
   TORCH_CHECK(var);
 }
 
-ExprHandle Buf::make(
-    const std::string& name_hint,
-    const std::vector<ExprHandle>& dims,
-    Dtype dtype) {
-  return ExprHandle(
-      alloc<Buf>(name_hint, ExprHandleVectorToExprVector(dims), dtype));
+BufHandle Buf::make(const std::vector<ExprHandle>& dims, Dtype dtype) {
+  return Buf::make("", dims, dtype);
 }
 
-ExprHandle Buf::make(const std::vector<ExprHandle>& dims, Dtype dtype) {
-  return Buf::make("", dims, dtype);
+BufHandle Buf::make(
+    const std::string& name_hint,
+    const std::vector<ExprHandle>& dims,
+    Dtype dtype,
+    c10::optional<ExprHandle> initializer,
+    c10::optional<std::vector<ExprHandle>> strides,
+    c10::optional<ExprHandle> qscale,
+    c10::optional<ExprHandle> qzero) {
+  c10::optional<std::vector<ExprPtr>> opt_strides;
+  if (strides) {
+    opt_strides = ExprHandleVectorToExprVector(*strides);
+  }
+  return BufHandle(alloc<Buf>(
+      name_hint,
+      ExprHandleVectorToExprVector(dims),
+      dtype,
+      initializer ? initializer->node() : nullptr,
+      opt_strides,
+      qscale ? qscale->node() : nullptr,
+      qzero ? qzero->node() : nullptr));
 }
 
 std::vector<ExprHandle> BufHandle::dims() const {
diff --git a/torch/csrc/jit/tensorexpr/expr.h b/torch/csrc/jit/tensorexpr/expr.h
index 148ce41..cf75e63 100644
--- a/torch/csrc/jit/tensorexpr/expr.h
+++ b/torch/csrc/jit/tensorexpr/expr.h
@@ -191,11 +191,16 @@
 
 class TORCH_API Buf : public ExprNode<Buf> {
  public:
-  static ExprHandle make(
+  static BufHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
+
+  static BufHandle make(
       const std::string& name_hint,
       const std::vector<ExprHandle>& dims,
-      Dtype dtype);
-  static ExprHandle make(const std::vector<ExprHandle>& dims, Dtype dtype);
+      Dtype dtype,
+      c10::optional<ExprHandle> initializer = c10::nullopt,
+      c10::optional<std::vector<ExprHandle>> strides = c10::nullopt,
+      c10::optional<ExprHandle> qscale = c10::nullopt,
+      c10::optional<ExprHandle> qzero = c10::nullopt);
 
   // TODO: unique_name
   VarPtr base_handle() const {
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index cc49ef3..93e76f6 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -2955,7 +2955,7 @@
         tmp_params,
         reduceOp->reducer()(
             producer,
-            ExprHandle(alloc<Load>(tmp_buf, new_loop_vars_expr)),
+            alloc<Load>(tmp_buf, new_loop_vars_expr),
             tmp_params,
             {}));
 
diff --git a/torch/csrc/jit/tensorexpr/operators/quantization.cpp b/torch/csrc/jit/tensorexpr/operators/quantization.cpp
index 1c99f05..a3f1a91 100644
--- a/torch/csrc/jit/tensorexpr/operators/quantization.cpp
+++ b/torch/csrc/jit/tensorexpr/operators/quantization.cpp
@@ -727,8 +727,8 @@
   auto input_height = ExprHandle(A.dim(2));
   auto input_width = ExprHandle(A.dim(3));
 
-  std::vector<ExprPtr> dims;
-  std::vector<VarPtr> args;
+  std::vector<ExprHandle> dims;
+  std::vector<VarHandle> args;
   unpack_dim_args(c10::fmap<DimArg>(outputShape), &dims, &args);
   // Handle separately when scale is specified? as in 'scalar_t
   // compute_scales_value' in UpSample.h
@@ -749,16 +749,16 @@
     newAxes[3] = compute_nearest_idx(scale_w, axes[3], input_width);
     return A.load(newAxes);
   };
-  auto e = body_func(VarVectorToVarHandleVector(args));
-  BufPtr buf = alloc<Buf>(
+  auto e = body_func(args);
+  BufHandle buf = Buf::make(
       "quantize_upsample_nearest2d",
-      ExprHandleVectorToExprVector(outputShape),
+      outputShape,
       Dtype(*outputType),
-      nullptr,
       c10::nullopt,
-      A.node()->qscale(),
-      A.node()->qzero());
-  return Tensor(buf, args, e.node());
+      c10::nullopt,
+      ExprHandle(A.node()->qscale()),
+      ExprHandle(A.node()->qzero()));
+  return Tensor(buf, args, e);
 }
 
 Tensor computeUpsampleNearest2dExternalCall(
diff --git a/torch/csrc/jit/tensorexpr/reduction.cpp b/torch/csrc/jit/tensorexpr/reduction.cpp
index 1727482..0cf8cb2 100644
--- a/torch/csrc/jit/tensorexpr/reduction.cpp
+++ b/torch/csrc/jit/tensorexpr/reduction.cpp
@@ -6,12 +6,12 @@
 namespace jit {
 namespace tensorexpr {
 
-ReduceOpPtr Reducer::operator()(
-    BufPtr result_buf,
+ExprHandle Reducer::operator()(
+    BufHandle result_buf,
     ExprHandle body,
-    const std::vector<ExprPtr>& output,
-    const std::vector<VarPtr>& inner) const {
-  return alloc<ReduceOp>(
+    const std::vector<ExprHandle>& output,
+    const std::vector<VarHandle>& inner) const {
+  return ReduceOp::make(
       complete(result_buf, interaction_, body, output, inner), inner, *this);
 }
 
@@ -26,6 +26,14 @@
       *this);
 }
 
+ExprHandle ReduceOp::make(
+    ExprHandle body,
+    std::vector<VarHandle> reduce_args,
+    const Reducer& reducer) {
+  return ExprHandle(alloc<ReduceOp>(
+      body.node(), VarHandleVectorToVarVector(reduce_args), reducer));
+}
+
 } // namespace tensorexpr
 } // namespace jit
 } // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/reduction.h b/torch/csrc/jit/tensorexpr/reduction.h
index 24df1ef..c1e8401 100644
--- a/torch/csrc/jit/tensorexpr/reduction.h
+++ b/torch/csrc/jit/tensorexpr/reduction.h
@@ -36,11 +36,11 @@
     return init_;
   }
 
-  ReduceOpPtr operator()(
-      BufPtr result_buf,
+  ExprHandle operator()(
+      BufHandle result_buf,
       ExprHandle body,
-      const std::vector<ExprPtr>& output,
-      const std::vector<VarPtr>& inner) const;
+      const std::vector<ExprHandle>& output,
+      const std::vector<VarHandle>& inner) const;
 
   ReduceOpPtr operator()(
       BufPtr result_buf,
@@ -111,6 +111,16 @@
     auto e = interaction(accum, body);
     return e.node();
   }
+  static ExprHandle complete(
+      BufHandle accumulator,
+      ReduceInteraction interaction,
+      ExprHandle body,
+      const std::vector<ExprHandle>& output_args,
+      const std::vector<VarHandle>& reduce_args) {
+    ExprHandle accum = Load::make(body.dtype(), accumulator, output_args);
+    auto e = interaction(accum, body);
+    return e;
+  }
 
  private:
   ExprPtr init_;
@@ -133,6 +143,10 @@
         body_(body),
         reduce_args_(std::move(reduce_args)),
         reducer_(reducer) {}
+  static ExprHandle make(
+      ExprHandle body,
+      std::vector<VarHandle> reduce_args,
+      const Reducer& reducer);
 
   // return the body expression which obtains the value to be reduced.
   ExprPtr body() const {
diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp
index c78f27f..10c6f96 100644
--- a/torch/csrc/jit/tensorexpr/tensor.cpp
+++ b/torch/csrc/jit/tensorexpr/tensor.cpp
@@ -53,11 +53,11 @@
     const std::string& name,
     const std::vector<DimArg>& dim_args,
     const std::function<ExprHandle(const std::vector<VarHandle>&)>& body_func) {
-  std::vector<ExprPtr> dims;
-  std::vector<VarPtr> args;
+  std::vector<ExprHandle> dims;
+  std::vector<VarHandle> args;
   unpack_dim_args(dim_args, &dims, &args);
-  ExprPtr body = body_func(VarVectorToVarHandleVector(args)).node();
-  BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+  ExprHandle body = body_func(args);
+  BufHandle buf = Buf::make(name, dims, body.dtype());
   return Tensor(buf, args, body);
 }
 
@@ -69,11 +69,11 @@
     throw malformed_input("mismatch between body and arg size (1)");
   }
 
-  std::vector<ExprPtr> dims;
-  std::vector<VarPtr> args;
+  std::vector<ExprHandle> dims;
+  std::vector<VarHandle> args;
   unpack_dim_args(dim_args, &dims, &args);
-  ExprPtr body = body_func(VarHandle(args[0])).node();
-  BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+  ExprHandle body = body_func(args[0]);
+  BufHandle buf = Buf::make(name, dims, body.dtype());
   return Tensor(buf, args, body);
 }
 
@@ -85,11 +85,11 @@
   if (dim_args.size() != 2) {
     throw malformed_input("mismatch between body and arg size (2)");
   }
-  std::vector<ExprPtr> dims;
-  std::vector<VarPtr> args;
+  std::vector<ExprHandle> dims;
+  std::vector<VarHandle> args;
   unpack_dim_args(dim_args, &dims, &args);
-  ExprPtr body = body_func(VarHandle(args[0]), VarHandle(args[1])).node();
-  BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+  ExprHandle body = body_func(args[0], args[1]);
+  BufHandle buf = Buf::make(name, dims, body.dtype());
   return Tensor(buf, args, body);
 }
 
@@ -102,13 +102,11 @@
   if (dim_args.size() != 3) {
     throw malformed_input("mismatch between body and arg size (3)");
   }
-  std::vector<ExprPtr> dims;
-  std::vector<VarPtr> args;
+  std::vector<ExprHandle> dims;
+  std::vector<VarHandle> args;
   unpack_dim_args(dim_args, &dims, &args);
-  ExprPtr body =
-      body_func(VarHandle(args[0]), VarHandle(args[1]), VarHandle(args[2]))
-          .node();
-  BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+  ExprHandle body = body_func(args[0], args[1], args[2]);
+  BufHandle buf = Buf::make(name, dims, body.dtype());
   return Tensor(buf, args, body);
 }
 
@@ -123,16 +121,11 @@
   if (dim_args.size() != 4) {
     throw malformed_input("mismatch between body and arg size (4)");
   }
-  std::vector<ExprPtr> dims;
-  std::vector<VarPtr> args;
+  std::vector<ExprHandle> dims;
+  std::vector<VarHandle> args;
   unpack_dim_args(dim_args, &dims, &args);
-  ExprPtr body = body_func(
-                     VarHandle(args[0]),
-                     VarHandle(args[1]),
-                     VarHandle(args[2]),
-                     VarHandle(args[3]))
-                     .node();
-  BufPtr buf = alloc<Buf>(name, dims, body->dtype());
+  ExprHandle body = body_func(args[0], args[1], args[2], args[3]);
+  BufHandle buf = Buf::make(name, dims, body.dtype());
   return Tensor(buf, args, body);
 }
 
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index 1824463..d2cfe0a 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -19,6 +19,8 @@
       : buf_(buf) {
     stmt_ = constructStmt(args, body, {}, {});
   }
+  Tensor(BufHandle buf, const std::vector<VarHandle>& args, ExprHandle body)
+      : Tensor(buf.node(), VarHandleVectorToVarVector(args), body.node()) {}
 
   // NOLINTNEXTLINE(cppcoreguidelines-pro-type-member-init)
   Tensor(
@@ -30,6 +32,18 @@
       : buf_(buf) {
     stmt_ = constructStmt(args, body, reduce_dims, reduce_args);
   }
+  Tensor(
+      BufHandle buf,
+      const std::vector<VarHandle>& args,
+      const std::vector<ExprHandle>& reduce_dims,
+      const std::vector<VarHandle>& reduce_args,
+      ExprHandle body)
+      : Tensor(
+            buf.node(),
+            VarHandleVectorToVarVector(args),
+            ExprHandleVectorToExprVector(reduce_dims),
+            VarHandleVectorToVarVector(reduce_args),
+            body.node()) {}
 
   Tensor(BufPtr buf, StmtPtr stmt) : buf_(buf), stmt_(stmt) {}
 
@@ -87,16 +101,16 @@
 
 inline void unpack_dim_args(
     const std::vector<DimArg>& dim_args,
-    std::vector<ExprPtr>* dims,
-    std::vector<VarPtr>* vars) {
+    std::vector<ExprHandle>* dims,
+    std::vector<VarHandle>* vars) {
   dims->clear();
   vars->clear();
   for (const DimArg& dim_arg : dim_args) {
-    ExprPtr expr = dim_arg.dim().node();
+    ExprHandle expr = dim_arg.dim();
     dims->push_back(expr);
-    vars->push_back(alloc<Var>(
+    vars->push_back(VarHandle(alloc<Var>(
         dim_arg.name_hint(),
-        expr->dtype().scalar_type() == ScalarType::Long ? kLong : kInt));
+        expr.dtype().scalar_type() == ScalarType::Long ? kLong : kInt)));
   }
 }
 
@@ -109,47 +123,31 @@
     const InitFunc& init_func,
     const BodyFunc& body_func,
     const std::vector<DimArg>& reduce_args) {
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<ExprPtr> dims;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<VarPtr> vars;
+  std::vector<ExprHandle> dims;
+  std::vector<VarHandle> vars;
   unpack_dim_args(dim_args, &dims, &vars);
 
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<ExprPtr> reduce_dims;
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<VarPtr> reduce_vars;
+  std::vector<ExprHandle> reduce_dims;
+  std::vector<VarHandle> reduce_vars;
   unpack_dim_args(reduce_args, &reduce_dims, &reduce_vars);
 
   // If reduce_vars is empty, then it's not a reduction, but rather a simple
   // copy
   if (reduce_vars.empty()) {
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    ExprPtr body =
-        Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(vars))
-            .node();
-    // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-    BufPtr func_result = alloc<Buf>(func_name, dims, body->dtype());
+    ExprHandle body = Reducer::getReduceBody(body_func, vars);
+    BufHandle func_result = Buf::make(func_name, dims, body.dtype());
     return Tensor(func_result, vars, body);
   }
 
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<VarPtr> all_vars;
+  std::vector<VarHandle> all_vars;
   all_vars.insert(all_vars.end(), vars.begin(), vars.end());
   all_vars.insert(all_vars.end(), reduce_vars.begin(), reduce_vars.end());
 
-  ExprHandle body =
-      Reducer::getReduceBody(body_func, VarVectorToVarHandleVector(all_vars));
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  std::vector<ExprPtr> output_args(vars.begin(), vars.end());
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  ExprPtr init_expr = alloc<Cast>(
-      body.dtype(), init_func(VarVectorToVarHandleVector(vars)).node());
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  BufPtr func_result = alloc<Buf>(func_name, dims, body.dtype(), init_expr);
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
-  ReduceOpPtr reduce_op = reducer(func_result, body, output_args, reduce_vars);
-  // NOLINTNEXTLINE(cppcoreguidelines-init-variables)
+  ExprHandle body = Reducer::getReduceBody(body_func, all_vars);
+  std::vector<ExprHandle> output_args(vars.begin(), vars.end());
+  ExprHandle init_expr = Cast::make(body.dtype(), init_func(vars));
+  BufHandle func_result = Buf::make(func_name, dims, body.dtype(), init_expr);
+  ExprHandle reduce_op = reducer(func_result, body, output_args, reduce_vars);
   Tensor t = Tensor(func_result, vars, reduce_dims, reduce_vars, reduce_op);
   return t;
 }