[TensorExpr] Simplify TE IR before applying any transformations. (#64717)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/64717
This also exposed several bugs, which are fixed in this PR.
Differential Revision: D30826408
Test Plan: Imported from OSS
Reviewed By: navahgar
Pulled By: ZolotukhinM
fbshipit-source-id: a67ec5739aceed9ffdf0d24f77eb3787cefe4560
diff --git a/test/cpp/tensorexpr/test_kernel.cpp b/test/cpp/tensorexpr/test_kernel.cpp
index 1a6d086..0bcef07 100644
--- a/test/cpp/tensorexpr/test_kernel.cpp
+++ b/test/cpp/tensorexpr/test_kernel.cpp
@@ -248,7 +248,11 @@
TensorExprKernel k(graph);
std::ostringstream oss;
oss << *k.getCodeGenStmt();
- const std::string& verification_pattern = "# CHECK: 4000000000";
+ // The 4000000000 iterations loop will be split into 500000000 x 8 and the
+ // outer loop will be parallel. If LLVM is not present, it will not be split,
+ // and to cover both of these cases we're looking for 00000000ll; in the
+ // output.
+ const std::string& verification_pattern = R"IR(# CHECK: 00000000ll;)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
}
@@ -629,17 +633,15 @@
const std::string& verification_pattern =
R"IR(
# CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat
# CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat
# CHECK: for
-# CHECK-NEXT: for
-# CHECK-NEXT: for
-# CHECK-NEXT: aten_cat)IR";
+# CHECK: aten_cat
+# CHECK: for
+# CHECK: for
+# CHECK: aten_cat
+# CHECK: for
+# CHECK: for
+# CHECK: aten_cat)IR";
torch::jit::testing::FileCheck().run(verification_pattern, oss.str());
auto a = at::rand({5, 3, 2}, TensorOptions(kCPU).dtype(at::kFloat));
diff --git a/torch/csrc/jit/tensorexpr/ir.h b/torch/csrc/jit/tensorexpr/ir.h
index 65a362e..4448ea4 100644
--- a/torch/csrc/jit/tensorexpr/ir.h
+++ b/torch/csrc/jit/tensorexpr/ir.h
@@ -902,11 +902,6 @@
IntrinsicsOp op_type_;
};
-class Polynomial;
-class Term;
-class MaxTerm;
-class MinTerm;
-
TORCH_API std::vector<ExprPtr> ExprHandleVectorToExprVector(
const std::vector<ExprHandle>&);
TORCH_API std::vector<ExprHandle> ExprVectorToExprHandleVector(
diff --git a/torch/csrc/jit/tensorexpr/kernel.cpp b/torch/csrc/jit/tensorexpr/kernel.cpp
index 17e4c96..cfb27ee 100644
--- a/torch/csrc/jit/tensorexpr/kernel.cpp
+++ b/torch/csrc/jit/tensorexpr/kernel.cpp
@@ -2551,6 +2551,9 @@
}
loopsToFuse.push_back(loop);
}
+ if (loopsToFuse.empty()) {
+ return;
+ }
if (!loopBoundsAllEqual(loopsToFuse)) {
return;
}
@@ -2658,6 +2661,8 @@
auto root_stmt = l.root_stmt();
root_stmt->accept(block_analysis.get());
}
+ l.simplify();
+ GRAPH_DEBUG("after simplify", *l.root_stmt());
// Inlining output & intermediate buffers can duplicate computation.
// Duplicating work can slow down the program if it's not ameliorated in some
@@ -3030,6 +3035,7 @@
// cur_idx = absolute // stride
// absolute = absolute % stride
+ auto zero = LongImm::make(0);
return Compute(
"output_1", dims, [&](const std::vector<VarHandle>& axes_input) {
std::vector<ExprHandle> axes(axes_input.begin(), axes_input.end());
@@ -3042,17 +3048,17 @@
reverse_sort_indices(strides);
std::vector<ExprHandle> new_axes(sorted_stride_indices.size());
for (size_t stride_index : sorted_stride_indices) {
- auto stride = strides[stride_index];
auto size = sizes[stride_index];
- auto index = absolute_position /
- ExprHandle(immLike(absolute_position, stride));
+ auto index = zero;
if (size != 1) {
+ auto stride = strides[stride_index];
+ index = absolute_position /
+ ExprHandle(immLike(absolute_position, stride));
absolute_position = absolute_position %
ExprHandle(immLike(absolute_position, stride));
}
new_axes[stride_index] = index;
}
- // NOLINTNEXTLINE(clang-analyzer-cplusplus.NewDeleteLeaks)
return BufHandle(buf).load(new_axes);
});
}
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index 0750b34..3742b93 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -747,6 +747,9 @@
bool LoopNest::computeInline(BufPtr b) {
// If buf is used or defined in an ExternalCall, we cannot inline it
auto buf_load_store_uses = findLoadOrStoreUses(root_stmt_);
+ if (!buf_load_store_uses.count(b)) {
+ return false;
+ }
for (auto& use : buf_load_store_uses.at(b)) {
StmtPtr s = use.s;
if (to<ExternalCall>(s)) {