[TensorExpr] Move 'lowerToStmt' method from 'LoopNest' to 'Tensor'. (#50994)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50994

Eventually, 'Tensor' will be fully responsible for its 'Stmt' and moving
this method to it is one step in that direction.

Test Plan: Imported from OSS

Reviewed By: bertmaher

Differential Revision: D26038222

Pulled By: ZolotukhinM

fbshipit-source-id: 0549f0ae6b46a93ff7608a22e79faa5115eef661
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index b189f04..6e63743 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -508,7 +508,7 @@
 
   std::vector<Stmt*> loops;
   for (Tensor* t : tensors_to_compute) {
-    Stmt* loop = lowerToStmt(t);
+    Stmt* loop = t->lowerToStmt();
     // Flatten initializers.
     if (Block* block = dynamic_cast<Block*>(loop)) {
       for (auto* s : block->stmts()) {
@@ -532,46 +532,6 @@
   }
 }
 
-Stmt* LoopNest::lowerToStmt(Tensor* t) {
-  Stmt* body = t->ElementStmt();
-
-  // If this Tensor has no functional body, it already has its axes expanded.
-  if (nullptr == t->body()) {
-    return body;
-  }
-
-  if (t->ndim() == 0 && t->reduce_ndim() == 0) {
-    return body;
-  }
-
-  const Expr* init_expr = t->buf()->initializer();
-
-  std::vector<const Expr*> indices(t->args().begin(), t->args().end());
-
-  if (t->reduce_ndim() > 0) {
-    for (size_t i = 0; i < t->reduce_ndim(); i++) {
-      // Going in reverse order: from innermost loop to the outermost
-      size_t dim_index = t->reduce_ndim() - i - 1;
-      body = new For(
-          t->reduce_arg(dim_index),
-          new IntImm(0),
-          t->reduce_dim(dim_index),
-          body);
-    }
-    if (init_expr) {
-      Store* init = new Store(t->buf(), indices, init_expr, new IntImm(1));
-      body = new Block({init, body});
-    }
-  }
-
-  for (size_t i = 0; i < t->ndim(); i++) {
-    // Going in reverse order: from innermost loop to the outermost
-    size_t dim_index = t->ndim() - i - 1;
-    body = new For(t->arg(dim_index), new IntImm(0), t->dim(dim_index), body);
-  }
-  return body;
-}
-
 class FunctionInliner : public IRMutator {
  public:
   FunctionInliner(Store* producer, std::unordered_set<const Buf*> outputs)
diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h
index 6c016f9..1bff2e9 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.h
+++ b/torch/csrc/jit/tensorexpr/loopnest.h
@@ -122,7 +122,6 @@
  private:
   std::vector<Tensor*> findAllNeededTensors(
       const std::vector<Tensor*>& tensors);
-  Stmt* lowerToStmt(Tensor* t);
   Stmt* insertAllocFree(Stmt* stmt);
 
   Stmt* root_stmt_;
diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp
index d12f699..e32f9f2 100644
--- a/torch/csrc/jit/tensorexpr/tensor.cpp
+++ b/torch/csrc/jit/tensorexpr/tensor.cpp
@@ -8,6 +8,43 @@
 namespace jit {
 namespace tensorexpr {
 
+Stmt* Tensor::lowerToStmt() const {
+  Stmt* s = ElementStmt();
+
+  // If this Tensor has no functional body, it already has its axes expanded.
+  if (nullptr == body()) {
+    return s;
+  }
+
+  if (ndim() == 0 && reduce_ndim() == 0) {
+    return s;
+  }
+
+  const Expr* init_expr = buf()->initializer();
+
+  std::vector<const Expr*> indices(args().begin(), args().end());
+
+  if (reduce_ndim() > 0) {
+    for (size_t i = 0; i < reduce_ndim(); i++) {
+      // Going in reverse order: from innermost loop to the outermost
+      size_t dim_index = reduce_ndim() - i - 1;
+      s = new For(
+          reduce_arg(dim_index), new IntImm(0), reduce_dim(dim_index), s);
+    }
+    if (init_expr) {
+      Store* init = new Store(buf(), indices, init_expr, new IntImm(1));
+      s = new Block({init, s});
+    }
+  }
+
+  for (size_t i = 0; i < ndim(); i++) {
+    // Going in reverse order: from innermost loop to the outermost
+    size_t dim_index = ndim() - i - 1;
+    s = new For(arg(dim_index), new IntImm(0), dim(dim_index), s);
+  }
+  return s;
+}
+
 Tensor* Compute(
     const std::string& func_name,
     const std::vector<DimArg>& dim_args,
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index 53e0faa..31454f7 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -99,6 +99,8 @@
   template <typename... Ts>
   inline ExprHandle call(const Ts&... ts);
 
+  Stmt* lowerToStmt() const;
+
  private:
   const Buf* buf_;
   std::vector<const Var*> args_;