reduce inline autodiff threshold so we can capture smaller fusions (#57062)
Summary:
This should let us fuse simpler expressions like
```python
@torch.jit.script
def foo(x):
    return torch.sigmoid(torch.sigmoid(x))
```
RUN_TORCHBENCH: alexnet attention_is_all_you_need_pytorch Background_Matting BERT_pytorch demucs densenet121 dlrm fastNLP gen_torchvision_benchmarks.py LearningToPaint maml mnasnet1_0 mobilenet_v2 mobilenet_v2_quantized_qat moco pyhpc_equation_of_state pyhpc_isoneutral_mixing pytorch_CycleGAN_and_pix2pix pytorch_mobilenet_v3 pytorch_stargan pytorch_struct resnet18 resnet50 resnext50_32x4d shufflenet_v2_x1_0 squeezenet1_1 Super_SloMo tacotron2 vgg16 yolov3
Pull Request resolved: https://github.com/pytorch/pytorch/pull/57062
Reviewed By: zou3519
Differential Revision: D28053608
Pulled By: Krovatkin
fbshipit-source-id: 6871c3d2a81dd326a481e7ecfaf2ffefffce4a89
diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py
index 2a89931..71d07ae 100644
--- a/test/jit/test_autodiff_subgraph_slicing.py
+++ b/test/jit/test_autodiff_subgraph_slicing.py
@@ -1,7 +1,8 @@
import os
import sys
import unittest
-from torch.testing._internal.common_utils import GRAPH_EXECUTOR, ProfilingMode, enable_profiling_mode_for_profiling_tests
+from torch.testing._internal.common_utils import GRAPH_EXECUTOR, ProfilingMode, \
+ num_profiled_runs, enable_profiling_mode_for_profiling_tests
from torch.testing._internal.common_jit import check_against_reference
import torch
@@ -51,6 +52,35 @@
output = func(input, profile_and_replay=True)
self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
+
+ @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "This threshold is only valid for Profiling Executor")
+ def test_diff_graph_inline_threshold(self):
+ with enable_profiling_mode_for_profiling_tests():
+ NUM_RUNS = 1
+ with num_profiled_runs(NUM_RUNS):
+ @torch.jit.script
+ def foo(x):
+
+ # two nodes should be fused
+ # see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
+ return torch.sigmoid(torch.sigmoid(x))
+
+ @torch.jit.script
+ def bar(x):
+ # a single node should NOT be fused
+ return torch.sigmoid(x)
+
+ input = torch.rand([4, 4], requires_grad=True)
+ foo(input)
+ foo(input)
+
+ bar(input)
+ bar(input)
+
+ print(foo.graph_for(input))
+ self.assertGraphContainsExactly(foo.graph_for(input), 'prim::DifferentiableGraph', 1)
+ self.assertGraphContainsExactly(bar.graph_for(input), 'prim::DifferentiableGraph', 0)
+
def test_bias_as_module_attr(self):
with enable_profiling_mode_for_profiling_tests():
diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
index 0170913..90cdb68 100644
--- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
+++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
@@ -477,7 +477,7 @@
}
InlineAutodiffSubgraphs(
copy,
- getAutodiffSubgraphInlining() ? autodiffSubgraphInlineThreshold : 1);
+ getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
replaceFallbackGraphWithFallbackFunction(copy->block());
RemoveProfilingNodes(copy);
GRAPH_DEBUG(