[TensorExpr] Move 'lowerToStmt' method from 'LoopNest' to 'Tensor'. (#50994)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50994
Eventually, 'Tensor' will be fully responsible for its own 'Stmt', and moving
this method to 'Tensor' is one step in that direction.
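
For reviewers, a minimal sketch of how the call pattern changes with this diff. The helper name 'lowerPointwiseExample', the 'add_one' kernel, the 64-element placeholder 'a', and the literal sizes are illustrative assumptions, not taken from this PR; only 'Compute' and the new 'Tensor::lowerToStmt()' are the real TensorExpr entry points being exercised:

```cpp
#include <torch/csrc/jit/tensorexpr/tensor.h>

using namespace torch::jit::tensorexpr;

// Hypothetical helper; `a` is assumed to be a 64-element float Placeholder.
Stmt* lowerPointwiseExample(Placeholder& a) {
  // Illustrative 1-D pointwise kernel: out[i] = a[i] + 1.
  Tensor* t = Compute("add_one", {{64, "i"}}, [&](const VarHandle& i) {
    return a.load(i) + 1.0f;
  });

  // Before this diff the loop nest was built through the private
  // LoopNest::lowerToStmt(Tensor*) helper; now the Tensor lowers itself.
  return t->lowerToStmt();
}
```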
Test Plan: Imported from OSS
Reviewed By: bertmaher
Differential Revision: D26038222
Pulled By: ZolotukhinM
fbshipit-source-id: 0549f0ae6b46a93ff7608a22e79faa5115eef661
diff --git a/torch/csrc/jit/tensorexpr/loopnest.cpp b/torch/csrc/jit/tensorexpr/loopnest.cpp
index b189f04..6e63743 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.cpp
+++ b/torch/csrc/jit/tensorexpr/loopnest.cpp
@@ -508,7 +508,7 @@
std::vector<Stmt*> loops;
for (Tensor* t : tensors_to_compute) {
- Stmt* loop = lowerToStmt(t);
+ Stmt* loop = t->lowerToStmt();
// Flatten initializers.
if (Block* block = dynamic_cast<Block*>(loop)) {
for (auto* s : block->stmts()) {
@@ -532,46 +532,6 @@
}
}
-Stmt* LoopNest::lowerToStmt(Tensor* t) {
- Stmt* body = t->ElementStmt();
-
- // If this Tensor has no functional body, it already has its axes expanded.
- if (nullptr == t->body()) {
- return body;
- }
-
- if (t->ndim() == 0 && t->reduce_ndim() == 0) {
- return body;
- }
-
- const Expr* init_expr = t->buf()->initializer();
-
- std::vector<const Expr*> indices(t->args().begin(), t->args().end());
-
- if (t->reduce_ndim() > 0) {
- for (size_t i = 0; i < t->reduce_ndim(); i++) {
- // Going in reverse order: from innermost loop to the outermost
- size_t dim_index = t->reduce_ndim() - i - 1;
- body = new For(
- t->reduce_arg(dim_index),
- new IntImm(0),
- t->reduce_dim(dim_index),
- body);
- }
- if (init_expr) {
- Store* init = new Store(t->buf(), indices, init_expr, new IntImm(1));
- body = new Block({init, body});
- }
- }
-
- for (size_t i = 0; i < t->ndim(); i++) {
- // Going in reverse order: from innermost loop to the outermost
- size_t dim_index = t->ndim() - i - 1;
- body = new For(t->arg(dim_index), new IntImm(0), t->dim(dim_index), body);
- }
- return body;
-}
-
class FunctionInliner : public IRMutator {
public:
FunctionInliner(Store* producer, std::unordered_set<const Buf*> outputs)
diff --git a/torch/csrc/jit/tensorexpr/loopnest.h b/torch/csrc/jit/tensorexpr/loopnest.h
index 6c016f9..1bff2e9 100644
--- a/torch/csrc/jit/tensorexpr/loopnest.h
+++ b/torch/csrc/jit/tensorexpr/loopnest.h
@@ -122,7 +122,6 @@
private:
std::vector<Tensor*> findAllNeededTensors(
const std::vector<Tensor*>& tensors);
- Stmt* lowerToStmt(Tensor* t);
Stmt* insertAllocFree(Stmt* stmt);
Stmt* root_stmt_;
diff --git a/torch/csrc/jit/tensorexpr/tensor.cpp b/torch/csrc/jit/tensorexpr/tensor.cpp
index d12f699..e32f9f2 100644
--- a/torch/csrc/jit/tensorexpr/tensor.cpp
+++ b/torch/csrc/jit/tensorexpr/tensor.cpp
@@ -8,6 +8,43 @@
namespace jit {
namespace tensorexpr {
+Stmt* Tensor::lowerToStmt() const {
+ Stmt* s = ElementStmt();
+
+ // If this Tensor has no functional body, it already has its axes expanded.
+ if (nullptr == body()) {
+ return s;
+ }
+
+ if (ndim() == 0 && reduce_ndim() == 0) {
+ return s;
+ }
+
+ const Expr* init_expr = buf()->initializer();
+
+ std::vector<const Expr*> indices(args().begin(), args().end());
+
+ if (reduce_ndim() > 0) {
+ for (size_t i = 0; i < reduce_ndim(); i++) {
+ // Going in reverse order: from innermost loop to the outermost
+ size_t dim_index = reduce_ndim() - i - 1;
+ s = new For(
+ reduce_arg(dim_index), new IntImm(0), reduce_dim(dim_index), s);
+ }
+ if (init_expr) {
+ Store* init = new Store(buf(), indices, init_expr, new IntImm(1));
+ s = new Block({init, s});
+ }
+ }
+
+ for (size_t i = 0; i < ndim(); i++) {
+ // Going in reverse order: from innermost loop to the outermost
+ size_t dim_index = ndim() - i - 1;
+ s = new For(arg(dim_index), new IntImm(0), dim(dim_index), s);
+ }
+ return s;
+}
+
Tensor* Compute(
const std::string& func_name,
const std::vector<DimArg>& dim_args,
diff --git a/torch/csrc/jit/tensorexpr/tensor.h b/torch/csrc/jit/tensorexpr/tensor.h
index 53e0faa..31454f7 100644
--- a/torch/csrc/jit/tensorexpr/tensor.h
+++ b/torch/csrc/jit/tensorexpr/tensor.h
@@ -99,6 +99,8 @@
template <typename... Ts>
inline ExprHandle call(const Ts&... ts);
+ Stmt* lowerToStmt() const;
+
private:
const Buf* buf_;
std::vector<const Var*> args_;
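
For context on what the moved function builds: 'Tensor::lowerToStmt()' wraps the tensor's 'ElementStmt()' in its reduction loops from innermost to outermost, prepends a store of 'buf()->initializer()' when one exists, and then wraps the result in the regular axes. Below is a hedged sketch of that on a small sum reduction; the helper name, the 8x16 placeholder, and the printed IR shape in the comments are illustrative assumptions, while 'Reduce' and 'Sum' are assumed to be the existing TensorExpr helpers:

```cpp
#include <torch/csrc/jit/tensorexpr/reduction.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>

using namespace torch::jit::tensorexpr;

// Hypothetical helper; `a` is assumed to be an 8x16 float Placeholder.
Stmt* lowerSumExample(Placeholder& a) {
  // out[i] = sum_j a[i, j]
  Tensor* t = Reduce("sum", {{8, "i"}}, Sum(), a, {{16, "j"}});

  // The resulting Stmt has roughly this shape (approximate, not captured
  // from a run):
  //   for i in [0, 8):
  //     sum[i] = 0                              // Store of buf()->initializer()
  //     for j in [0, 16):
  //       sum[i] = ReduceOp(sum[i] + a[i, j])   // ElementStmt()
  return t->lowerToStmt();
}
```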