[TensorExpr Fuser] Add support for nodes which have tensor constant inputs (#47814)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/47814
Previously, we would bail completely if a node had a constant tensor input. This PR adds support for this case by lifting the constant out of the fusion graph after we've done fusion: the tensor constant is re-inserted in the outer graph and passed to the fusion group as an ordinary input. It might be nice to support tensor constants in NNC itself, but that looked tricky, so this is a simple interim solution.
Test Plan: Imported from OSS
Reviewed By: bertmaher
Differential Revision: D25286215
Pulled By: eellison
fbshipit-source-id: 9ff67f92f5a2d43fd3ca087569898666525ca8cf
diff --git a/test/jit/test_profiler.py b/test/jit/test_profiler.py
index e42f822..e763a73 100644
--- a/test/jit/test_profiler.py
+++ b/test/jit/test_profiler.py
@@ -25,7 +25,8 @@
self.default_dtype = torch.get_default_dtype()
self.old_reduction_enabled = torch._C._jit_set_texpr_reductions_enabled(True)
torch.set_default_dtype(torch.double)
-
+ self.old_fusion_inlining = torch._C._debug_get_fusion_group_inlining()
+ torch._C._debug_set_fusion_group_inlining(False)
def tearDown(self):
torch._C._jit_set_profiling_executor(self.prev_exec)
@@ -35,6 +36,7 @@
torch._C._jit_override_can_fuse_on_cpu(self.can_fuse_on_cpu)
torch.set_default_dtype(self.default_dtype)
torch._C._jit_set_texpr_reductions_enabled(self.old_reduction_enabled)
+ torch._C._debug_set_fusion_group_inlining(self.old_fusion_inlining)
def test_tensor_type_not_determined_by_inputs(self):
@torch.jit.script
@@ -212,6 +214,19 @@
g = torch.jit.last_executed_optimized_graph()
FileCheck().check("fallback_function").check_next("CallFunction").run(g)
+ def test_tensor_constant(self):
+ def foo(a, b):
+ return a + b + torch.tensor([2])
+
+ x = torch.ones(1, requires_grad=False)
+ foo_script = torch.jit.script(foo)
+ foo_script(x, x)
+ foo_script(x, x)
+
+ self.assertEqual(foo_script(x, x), foo(x, x))
+ g = torch.jit.last_executed_optimized_graph()
+ FileCheck().check_count("aten::add", 2, exactly=True).run(g)
+
def test_iterative_fusion(self):
@torch.jit.script
def foo(a, b, c, d):
diff --git a/torch/csrc/jit/passes/tensorexpr_fuser.cpp b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
index 917d88a..6f587b9 100644
--- a/torch/csrc/jit/passes/tensorexpr_fuser.cpp
+++ b/torch/csrc/jit/passes/tensorexpr_fuser.cpp
@@ -459,7 +459,7 @@
// fusion is done.
inlineSmallFusionGroups(graph_->block());
GRAPH_DUMP("After inlining small fusion groups: ", graph_);
- guardFusionGroupsAndRemoveOutputs(graph_->block());
+ prepareFusionGroupAndGuardOutputs(graph_->block());
GRAPH_DUMP("After guarding fusion groups: ", graph_);
removeTensorTypeSpecializations(graph_->block());
GRAPH_DUMP("After removing tensor type specializations: ", graph_);
@@ -763,17 +763,10 @@
}
bool canHandle(Node* node) {
- REQ(node->kind() != prim::Constant);
REQ(disable_shape_checks_ || allShapesAreKnown(node));
REQ(isFusableOnDevice(node));
- // Don't include nodes whose inputs are tensor constants - we cannot handle
- // them at the moment.
- // TODO: actually support tensor constants and remove this.
for (Value* input : node->inputs()) {
- if (input->node()->kind() == prim::Constant) {
- REQ(!input->type()->cast<TensorType>())
- }
if (auto const& tt = input->type()->cast<TensorType>()) {
auto st = tt->scalarType();
if (!st) {
@@ -975,11 +968,32 @@
}
}
- void guardFusionGroupsAndRemoveOutputs(Block* block) {
+ // TODO: support constant tensors instead of setting them as input
+ void liftTensorConstantsFromFusionGroups(Node* fusion_group) {
+ auto subgraph = SubgraphUtils::getSubgraph(fusion_group);
+ WithInsertPoint guard(fusion_group);
+ for (auto it = subgraph->block()->nodes().begin();
+ it != subgraph->block()->nodes().end();
+ ++it) {
+ auto n = *it;
+ if (n->kind() == prim::Constant &&
+ n->output()->type()->cast<TensorType>()) {
+ auto constant =
+ fusion_group->owningGraph()->insertConstant(*toIValue(n->output()));
+ fusion_group->addInput(constant);
+ auto inputToGraph = subgraph->addInput();
+ inputToGraph->setType(n->output()->type());
+ n->output()->replaceAllUsesWith(inputToGraph);
+ it.destroyCurrent();
+ }
+ }
+ }
+
+ void prepareFusionGroupAndGuardOutputs(Block* block) {
std::vector<Node*> fusion_groups;
for (Node* n : block->nodes()) {
for (Block* b : n->blocks()) {
- guardFusionGroupsAndRemoveOutputs(b);
+ prepareFusionGroupAndGuardOutputs(b);
}
if (n->kind() == prim::TensorExprGroup) {
fusion_groups.push_back(n);
@@ -987,6 +1001,7 @@
}
for (Node* fusion_group : fusion_groups) {
removeOutputsUsedOnlyInSize(fusion_group);
+ liftTensorConstantsFromFusionGroups(fusion_group);
guardFusionGroup(fusion_group);
}
}