[TensorExpr] Factor out LoopNest::insertAllocFree. (#35175)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/35175
Differential Revision: D20585576
Test Plan: Imported from OSS
Pulled By: ZolotukhinM
fbshipit-source-id: 498b7ddf44df11392f6b5454387a29c5457bdb05
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 415b782..9893310 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -678,24 +678,11 @@
inlined_random_functions_.insert(stmt_to_tensor_.at(s)->function());
}
-void LoopNest::prepareForCodegen() {
- // TODO: check if `s` is a body of a loop
- std::vector<Function*> inlined_functions_vec(
- inlined_functions_.begin(), inlined_functions_.end());
- std::vector<Function*> inlined_randoms_vec(
- inlined_random_functions_.begin(), inlined_random_functions_.end());
- root_stmt_ = InjectInlines(root_stmt_, inlined_functions_vec);
- root_stmt_ = InlineRandom(root_stmt_, inlined_randoms_vec);
-
- // Flatten function calls.
- Flattener flattener;
- Stmt* core_stmt = root_stmt_->accept_mutator(&flattener);
-
+Stmt* LoopNest::insertAllocFree(Stmt* stmt) {
// Add allocs and frees for intermediate buffers at the global level.
// TODO: move allocs and frees to the imemediate areas to reuse buffers.
if (intermediate_tensors_.size() == 0ULL) {
- root_stmt_ = core_stmt;
- return;
+ return stmt;
}
std::vector<Stmt*> allocs;
std::vector<Stmt*> frees;
@@ -720,8 +707,24 @@
std::reverse(frees.begin(), frees.end());
Stmt* alloc_block = Block::make(allocs);
Stmt* free_block = Block::make(frees);
- Stmt* combined_stmt = Block::make({alloc_block, core_stmt, free_block});
- root_stmt_ = combined_stmt;
+ Stmt* combined_stmt = Block::make({alloc_block, stmt, free_block});
+ return combined_stmt;
+}
+
+void LoopNest::prepareForCodegen() {
+ std::vector<Function*> inlined_functions_vec(
+ inlined_functions_.begin(), inlined_functions_.end());
+ std::vector<Function*> inlined_randoms_vec(
+ inlined_random_functions_.begin(), inlined_random_functions_.end());
+ root_stmt_ = InjectInlines(root_stmt_, inlined_functions_vec);
+ root_stmt_ = InlineRandom(root_stmt_, inlined_randoms_vec);
+
+ // Flatten function calls.
+ Flattener flattener;
+ root_stmt_ = root_stmt_->accept_mutator(&flattener);
+
+ // Add allocs and frees for intermediate buffers at the global level.
+ root_stmt_ = insertAllocFree(root_stmt_);
}
void LoopNest::splitWithTail(
diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h
index 6fa3d97..4add2f3 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.h
+++ b/torch/csrc/jit/tensorexpr/loopnest.h
@@ -45,6 +45,7 @@
std::vector<Tensor*> findAllNeededTensors(
const std::vector<Tensor*>& tensors);
Stmt* lowerToStmt(Tensor* t);
+ Stmt* insertAllocFree(Stmt* stmt);
std::unordered_set<Function*> inlined_functions_;
std::unordered_set<Function*> inlined_random_functions_;