[TensorExpr] Add some graph-rewrite passes to prepare models for AOT compilation. (#66515)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/66515
These passes should not be used generally as they change API of the
model's forward method, but they help experimenting with the model and
ironing out all the kinks before it can be compiled properly. In the
long run ideally we should provide a better way to enable such
experiments.
Differential Revision: D31590862
Test Plan: Imported from OSS
Reviewed By: navahgar
Pulled By: ZolotukhinM
fbshipit-source-id: 74ded34c6c871d4cafa29f43dc27c7e71daff8fc
diff --git a/test/cpp/tensorexpr/test_graph_opt.cpp b/test/cpp/tensorexpr/test_graph_opt.cpp
index e5a237f..649bbb2 100644
--- a/test/cpp/tensorexpr/test_graph_opt.cpp
+++ b/test/cpp/tensorexpr/test_graph_opt.cpp
@@ -3,6 +3,8 @@
#include <test/cpp/tensorexpr/test_base.h>
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/ir/irparser.h>
+#include <torch/csrc/jit/passes/lower_tuples.h>
+#include <torch/csrc/jit/tensorexpr/graph_opt.h>
#include <torch/csrc/jit/tensorexpr/kernel.h>
#include <torch/csrc/jit/testing/file_check.h>
#include <torch/torch.h>
@@ -299,5 +301,20 @@
#endif
}
+TEST_F(GraphOpt, AOTGraphPrepPasses) {
+ const auto graph_string = R"IR(
+ graph(%x, %y, %z, %i : int):
+ %xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
+ return (%xyz_list, %i))IR";
+ auto g = std::make_shared<Graph>();
+ torch::jit::parseIR(graph_string, g.get());
+
+ removeGraphOutput(g, 1);
+ replaceListOutputWithTuple(g);
+ LowerAllTuples(g);
+
+ testing::FileCheck().check("return (%x, %y, %z)")->run(*g);
+}
+
} // namespace jit
} // namespace torch
diff --git a/test/cpp/tensorexpr/test_reductions.cpp b/test/cpp/tensorexpr/test_reductions.cpp
index 27684b5..dc0ecce 100644
--- a/test/cpp/tensorexpr/test_reductions.cpp
+++ b/test/cpp/tensorexpr/test_reductions.cpp
@@ -1550,7 +1550,6 @@
StmtPtr result = IRSimplifier::simplify(l.root_stmt());
SimpleIREvaluator cg(result, {a, b, e});
- std::cout << *cg.stmt() << std::endl;
std::ostringstream oss;
oss << *cg.stmt();
const std::string& expected_ir =
diff --git a/torch/csrc/jit/tensorexpr/graph_opt.cpp b/torch/csrc/jit/tensorexpr/graph_opt.cpp
index 6d5ac13..8891bd9 100644
--- a/torch/csrc/jit/tensorexpr/graph_opt.cpp
+++ b/torch/csrc/jit/tensorexpr/graph_opt.cpp
@@ -402,6 +402,26 @@
}
}
+std::shared_ptr<Graph> removeGraphOutput(
+ const std::shared_ptr<Graph>& graph,
+ size_t idx) {
+ graph->eraseOutput(idx);
+ return graph;
+}
+
+std::shared_ptr<Graph> replaceListOutputWithTuple(
+ const std::shared_ptr<Graph>& graph) {
+ auto out = graph->outputs()[0];
+ auto out_node = out->node();
+ if (out_node->kind() != prim::ListConstruct) {
+ return graph;
+ }
+ auto tuple_node = graph->createTuple(out_node->inputs());
+ tuple_node->insertAfter(out_node);
+ out->replaceAllUsesWith(tuple_node->output());
+ return graph;
+}
+
} // namespace tensorexpr
} // namespace jit
} // namespace torch
diff --git a/torch/csrc/jit/tensorexpr/graph_opt.h b/torch/csrc/jit/tensorexpr/graph_opt.h
index e365830..257255e 100644
--- a/torch/csrc/jit/tensorexpr/graph_opt.h
+++ b/torch/csrc/jit/tensorexpr/graph_opt.h
@@ -63,6 +63,11 @@
const std::vector<c10::optional<at::Tensor>>& example_inputs);
TORCH_API std::shared_ptr<Graph> removeUnusedSelfArgument(
const std::shared_ptr<Graph>& graph);
+TORCH_API std::shared_ptr<Graph> removeGraphOutput(
+ const std::shared_ptr<Graph>& graph,
+ size_t idx);
+TORCH_API std::shared_ptr<Graph> replaceListOutputWithTuple(
+ const std::shared_ptr<Graph>& graph);
// Scan all values in the given graph and replace each dimension with a size Xi
// present in \p SIZES with a symbolic shape Yi. Return a vector of symbol
diff --git a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
index cf462ed..717ce05 100644
--- a/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
+++ b/torch/csrc/jit/tensorexpr/tensorexpr_init.cpp
@@ -890,6 +890,10 @@
te.def("make_shapes_symbolic", &tensorexpr::makeShapesSymbolic);
te.def("is_graph_compilable", &tensorexpr::isGraphCompilable);
te.def("fixup_missing_shape_info", &tensorexpr::fixupMissingShapeInfo);
+ te.def("remove_graph_output", &tensorexpr::removeGraphOutput);
+ te.def(
+ "replace_list_output_with_tuple",
+ &tensorexpr::replaceListOutputWithTuple);
#ifdef TORCH_ENABLE_LLVM
te.def("set_llvm_target_triple", [](const c10::optional<std::string>& val) {
tensorexpr::LLVMTargetTriple() = val;