reduce inline autodiff threshold so we can capture smaller fusions (#57062)

Summary:
This switches the inlining pass to the smaller `autodiffSubgraphNodeThreshold`, so differentiable subgraphs with as few as two nodes are kept instead of being inlined back into the graph. That should let us fuse simpler expressions like:

```python
@torch.jit.script
def foo(x):
    return torch.sigmoid(torch.sigmoid(x))
```
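As a rough sketch of how to observe the effect (this mirrors the new test below; it assumes the default profiling executor and that two calls are enough warm-up to trigger the optimized plan):

```python
import torch

@torch.jit.script
def foo(x):
    # Two sigmoid nodes: enough to be kept as a differentiable
    # subgraph under the lowered threshold.
    return torch.sigmoid(torch.sigmoid(x))

x = torch.rand(4, 4, requires_grad=True)
foo(x)  # profiling run
foo(x)  # optimized run

# With this change the optimized graph should contain a
# prim::DifferentiableGraph node wrapping both sigmoids.
print(foo.graph_for(x))
```

A single-node function like `torch.sigmoid(x)` should still be inlined and show no `prim::DifferentiableGraph` node, which is what the `bar` case in the test checks.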

RUN_TORCHBENCH: alexnet attention_is_all_you_need_pytorch Background_Matting BERT_pytorch demucs densenet121 dlrm fastNLP gen_torchvision_benchmarks.py LearningToPaint maml mnasnet1_0 mobilenet_v2 mobilenet_v2_quantized_qat moco pyhpc_equation_of_state pyhpc_isoneutral_mixing pytorch_CycleGAN_and_pix2pix pytorch_mobilenet_v3 pytorch_stargan pytorch_struct resnet18 resnet50 resnext50_32x4d shufflenet_v2_x1_0 squeezenet1_1 Super_SloMo tacotron2 vgg16 yolov3

Pull Request resolved: https://github.com/pytorch/pytorch/pull/57062

Reviewed By: zou3519

Differential Revision: D28053608

Pulled By: Krovatkin

fbshipit-source-id: 6871c3d2a81dd326a481e7ecfaf2ffefffce4a89
diff --git a/test/jit/test_autodiff_subgraph_slicing.py b/test/jit/test_autodiff_subgraph_slicing.py
index 2a89931..71d07ae 100644
--- a/test/jit/test_autodiff_subgraph_slicing.py
+++ b/test/jit/test_autodiff_subgraph_slicing.py
@@ -1,7 +1,8 @@
 import os
 import sys
 import unittest
-from torch.testing._internal.common_utils import GRAPH_EXECUTOR, ProfilingMode, enable_profiling_mode_for_profiling_tests
+from torch.testing._internal.common_utils import GRAPH_EXECUTOR, ProfilingMode, \
+    num_profiled_runs, enable_profiling_mode_for_profiling_tests
 from torch.testing._internal.common_jit import check_against_reference
 import torch
 
@@ -51,6 +52,35 @@
                 output = func(input, profile_and_replay=True)
                 self.assertAutodiffNode(func.graph_for(input), True, ['prim::ConstantChunk'], [])
 
+
+    @unittest.skipIf(GRAPH_EXECUTOR != ProfilingMode.PROFILING, "This threshold is only valid for Profiling Executor")
+    def test_diff_graph_inline_threshold(self):
+        with enable_profiling_mode_for_profiling_tests():
+            NUM_RUNS = 1
+            with num_profiled_runs(NUM_RUNS):
+                @torch.jit.script
+                def foo(x):
+
+                    #  two nodes should be fused
+                    #  see https://github.com/pytorch/pytorch/blob/master/torch/csrc/jit/runtime/graph_executor_impl.h#L49
+                    return torch.sigmoid(torch.sigmoid(x))
+
+                @torch.jit.script
+                def bar(x):
+                    #  two nodes should NOT be fused
+                    return torch.sigmoid(x)
+
+                input = torch.rand([4, 4], requires_grad=True)
+                foo(input)
+                foo(input)
+
+                bar(input)
+                bar(input)
+
+                print(foo.graph_for(input))
+                self.assertGraphContainsExactly(foo.graph_for(input), 'prim::DifferentiableGraph', 1)
+                self.assertGraphContainsExactly(bar.graph_for(input), 'prim::DifferentiableGraph', 0)
+
     def test_bias_as_module_attr(self):
 
         with enable_profiling_mode_for_profiling_tests():
diff --git a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
index 0170913..90cdb68 100644
--- a/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
+++ b/torch/csrc/jit/runtime/profiling_graph_executor_impl.cpp
@@ -477,7 +477,7 @@
     }
     InlineAutodiffSubgraphs(
         copy,
-        getAutodiffSubgraphInlining() ? autodiffSubgraphInlineThreshold : 1);
+        getAutodiffSubgraphInlining() ? autodiffSubgraphNodeThreshold : 1);
     replaceFallbackGraphWithFallbackFunction(copy->block());
     RemoveProfilingNodes(copy);
     GRAPH_DEBUG(