[TensorExpr] Factor out LoopNest::insertAllocFree. (#35175)

Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35175

Differential Revision: D20585576

Test Plan: Imported from OSS

Pulled By: ZolotukhinM

fbshipit-source-id: 498b7ddf44df11392f6b5454387a29c5457bdb05
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 415b782..9893310 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -678,24 +678,11 @@
   inlined_random_functions_.insert(stmt_to_tensor_.at(s)->function());
 }
 
-void LoopNest::prepareForCodegen() {
-  // TODO: check if `s` is a body of a loop
-  std::vector<Function*> inlined_functions_vec(
-      inlined_functions_.begin(), inlined_functions_.end());
-  std::vector<Function*> inlined_randoms_vec(
-      inlined_random_functions_.begin(), inlined_random_functions_.end());
-  root_stmt_ = InjectInlines(root_stmt_, inlined_functions_vec);
-  root_stmt_ = InlineRandom(root_stmt_, inlined_randoms_vec);
-
-  // Flatten function calls.
-  Flattener flattener;
-  Stmt* core_stmt = root_stmt_->accept_mutator(&flattener);
-
+Stmt* LoopNest::insertAllocFree(Stmt* stmt) {
   // Add allocs and frees for intermediate buffers at the global level.
   // TODO: move allocs and frees to the imemediate areas to reuse buffers.
   if (intermediate_tensors_.size() == 0ULL) {
-    root_stmt_ = core_stmt;
-    return;
+    return stmt;
   }
   std::vector<Stmt*> allocs;
   std::vector<Stmt*> frees;
@@ -720,8 +707,24 @@
   std::reverse(frees.begin(), frees.end());
   Stmt* alloc_block = Block::make(allocs);
   Stmt* free_block = Block::make(frees);
-  Stmt* combined_stmt = Block::make({alloc_block, core_stmt, free_block});
-  root_stmt_ = combined_stmt;
+  Stmt* combined_stmt = Block::make({alloc_block, stmt, free_block});
+  return combined_stmt;
+}
+
+void LoopNest::prepareForCodegen() {
+  std::vector<Function*> inlined_functions_vec(
+      inlined_functions_.begin(), inlined_functions_.end());
+  std::vector<Function*> inlined_randoms_vec(
+      inlined_random_functions_.begin(), inlined_random_functions_.end());
+  root_stmt_ = InjectInlines(root_stmt_, inlined_functions_vec);
+  root_stmt_ = InlineRandom(root_stmt_, inlined_randoms_vec);
+
+  // Flatten function calls, storing the result back into root_stmt_.
+  Flattener flattener;
+  root_stmt_ = root_stmt_->accept_mutator(&flattener);
+
+  // Wrap the statement in allocs/frees for intermediate buffers (if any).
+  root_stmt_ = insertAllocFree(root_stmt_);
 }
 
 void LoopNest::splitWithTail(
diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h
index 6fa3d97..4add2f3 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.h
+++ b/torch/csrc/jit/tensorexpr/loopnest.h
@@ -45,6 +45,7 @@
   std::vector<Tensor*> findAllNeededTensors(
       const std::vector<Tensor*>& tensors);
   Stmt* lowerToStmt(Tensor* t);
+  Stmt* insertAllocFree(Stmt* stmt);
 
   std::unordered_set<Function*> inlined_functions_;
   std::unordered_set<Function*> inlined_random_functions_;