Add post-freezing optimizations, turned on by default in torch.jit.freeze (#50222)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50222

This PR adds a pass that runs a set of optimizations after freezing. Currently this encompasses Conv-BN folding and Conv->Add/Sub/Mul/Div folding; I am also planning to add dropout removal.
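
For background, Conv-BN folding is possible because a BatchNorm in eval mode is just a per-channel affine transform with constant parameters, so it can be absorbed into the convolution's weight and bias. Below is a minimal numeric sketch of that arithmetic (illustration only; the actual pass rewrites the frozen graph's IR rather than module parameters):

```python
import torch

conv = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
bn = torch.nn.BatchNorm2d(8, eps=1e-3).eval()
with torch.no_grad():
    # give BN non-trivial parameters/statistics so the check is meaningful
    bn.weight.copy_(torch.rand(8))
    bn.bias.copy_(torch.rand(8))
    bn.running_mean.copy_(torch.rand(8))
    bn.running_var.copy_(torch.rand(8) + 0.5)

# In eval mode, BN(y) = gamma * (y - mean) / sqrt(var + eps) + beta is affine
# in y, so it folds into the conv's parameters:
#   scale = gamma / sqrt(var + eps);  W' = W * scale;  b' = (b - mean) * scale + beta
scale = bn.weight / torch.sqrt(bn.running_var + bn.eps)
folded = torch.nn.Conv2d(3, 8, kernel_size=3, bias=True).eval()
with torch.no_grad():
    folded.weight.copy_(conv.weight * scale.reshape(-1, 1, 1, 1))
    folded.bias.copy_((conv.bias - bn.running_mean) * scale + bn.bias)

x = torch.rand(1, 3, 16, 16)
assert torch.allclose(bn(conv(x)), folded(x), atol=1e-4)
```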

I would like some feedback on the API. torch.jit.freeze is technically in the prototype phase, so we have some leeway to make changes. I think in the majority of cases the user is going to want to freeze their model and then run it for inference, so I would prefer the optimization to be opt-out instead of opt-in. All internal/framework use cases of freezing use `freeze_module`, not the Python API, so this shouldn't break anything.
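
Concretely, with the signature in this PR the opt-out looks like this (mirroring the new test below):

```python
import torch

mod = torch.nn.Sequential(
    torch.nn.Conv2d(3, 32, kernel_size=3),
    torch.nn.BatchNorm2d(32),
)
scripted = torch.jit.script(mod.eval())

frozen = torch.jit.freeze(scripted)               # runs the optimizations by default
raw = torch.jit.freeze(scripted, optimize=False)  # opts out of them
torch.jit.optimize_frozen_module(raw)             # ...or runs them explicitly later
```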

I have exposed the optimization pass as a separate API to keep things modular, even though I suspect separate use will be rare. In a future PR I would like to add a `torch::jit::freeze` for C++ use that follows the same API as `torch.jit.freeze` and runs the optimizations.

Test Plan: Imported from OSS

Reviewed By: tugsbayasgalan

Differential Revision: D25856264

Pulled By: eellison

fbshipit-source-id: 56be1f12cfc459b4c4421d4dfdedff8b9ac77112
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index 7591295..bd31a7e 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1501,3 +1501,19 @@
             # add with different dtype
             test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False,
                              add_tensor=torch.rand(1).to(torch.int), expect_success=False)
+
+    def test_optimize_freeze_module(self):
+        in_channels, out_channels = 3, 32
+        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
+        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
+        mod = torch.nn.Sequential(conv, bn)
+        # set optimize to False here; by default freezing runs optimize_frozen_module
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
+        # inspect frozen mod
+        FileCheck().check("batch_norm").run(frozen_mod.graph)
+        torch.jit.optimize_frozen_module(frozen_mod)
+        FileCheck().check_not("batch_norm").run(frozen_mod.graph)
+
+        # by default, optimize_frozen_module should be run
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
+        FileCheck().check_not("batch_norm").run(frozen_mod.graph)
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index db1442c..33a38d5 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -201,6 +201,7 @@
     "torch/csrc/jit/passes/prepack_folding.cpp",
     "torch/csrc/jit/passes/fold_conv_bn.cpp",
     "torch/csrc/jit/passes/frozen_conv_folding.cpp",
+    "torch/csrc/jit/passes/frozen_graph_optimizations.cpp",
     "torch/csrc/jit/passes/remove_expands.cpp",
     "torch/csrc/jit/passes/remove_dropout.cpp",
     "torch/csrc/jit/passes/requires_grad_analysis.cpp",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index bbe63f6..bf75dc1 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -170,6 +170,7 @@
                    preserved_attrs: List[str] = [],
                    freeze_interfaces: _bool = True,
                    preserveParameters: _bool = True) -> ScriptModule: ...
+def _jit_pass_optimize_frozen_graph(graph: Graph) -> None: ...
 def _is_tracing() -> _bool: ...
 def _jit_init() -> _bool: ...
 def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
new file mode 100644
index 0000000..7387025
--- /dev/null
+++ b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
@@ -0,0 +1,21 @@
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/ir/ir_views.h>
+#include <torch/csrc/jit/passes/frozen_conv_folding.h>
+#include <torch/csrc/jit/runtime/graph_executor.h>
+#include <torch/csrc/utils/memory.h>
+
+namespace torch {
+namespace jit {
+
+void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) {
+  // run a couple of times to capture sequences like Conv -> Mul -> Add
+  for (size_t i = 0; i < 2; i++) {
+    FoldFrozenConvBatchnorm(graph);
+    FoldFrozenConvAddOrSub(graph);
+    FoldFrozenConvMulOrDiv(graph);
+  }
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.h b/torch/csrc/jit/passes/frozen_graph_optimizations.h
new file mode 100644
index 0000000..a808791
--- /dev/null
+++ b/torch/csrc/jit/passes/frozen_graph_optimizations.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+/** \brief Runs a set of optimization passes on frozen graphs
+ *
+ * Currently this set of optimizations is:
+ * - FoldFrozenConvBatchnorm
+ * - FoldFrozenConvAddOrSub
+ * - FoldFrozenConvMulOrDiv
+ */
+
+namespace torch {
+namespace jit {
+
+TORCH_API void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 75b6831..99a0330 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -23,6 +23,7 @@
 #include <torch/csrc/jit/passes/fold_conv_bn.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
 #include <torch/csrc/jit/passes/fuse_linear.h>
 #include <torch/csrc/jit/passes/fuse_relu.h>
 #include <torch/csrc/jit/passes/graph_fuser.h>
@@ -299,6 +300,7 @@
       .def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
       .def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
       .def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
+      .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
       .def("_jit_pass_fuse_linear", &FuseLinear)
       .def(
           "_jit_pass_fuse_add_relu",
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index cfd3271..bf36436 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -45,7 +45,7 @@
 from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
 
 from torch.jit.cuda import stream
-from torch.jit._freeze import freeze
+from torch.jit._freeze import freeze, optimize_frozen_module
 
 # For backwards compatibility
 _fork = fork
@@ -93,20 +93,20 @@
     return _script_if_tracing(fn)
 
 
-# for torch.jit.isinstance 
+# for torch.jit.isinstance
 def isinstance(obj, target_type):
     """
-    This function provides for conatiner type refinement in TorchScript. It can refine 
+    This function provides for container type refinement in TorchScript. It can refine
     parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``,
-    ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also 
+    ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also
     refine basic types such as bools and ints that are available in TorchScript.
 
     Args:
         obj: object to refine the type of
-        target_type: type to try to refine obj to 
+        target_type: type to try to refine obj to
     Returns:
-        ``bool``: True if obj was successfully refined to the type of target_type, 
-            False otherwise with no new type refinement     
+        ``bool``: True if obj was successfully refined to the type of target_type,
+            False otherwise with no new type refinement
 
 
     Example (using ``torch.jit.isinstance`` for type refinement):
diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py
index 5db57f0..98e53c6 100644
--- a/torch/jit/_freeze.py
+++ b/torch/jit/_freeze.py
@@ -10,7 +10,7 @@
 from torch.jit._script import RecursiveScriptModule, ScriptModule
 
 
-def freeze(mod, preserved_attrs: Optional[List[str]] = None):
+def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):
     r"""
     Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
     module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
@@ -26,6 +26,11 @@
         preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
         Attributes modified in preserved methods will also be preserved.
 
+        optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph for inference,
+        in addition to the graph cleanup that already occurs. The details of the optimizations can be found in
+        ``torch.jit.optimize_frozen_module``.
+
+
     Returns:
         Frozen :class:`ScriptModule`.
 
@@ -97,5 +102,44 @@
 
     out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
     RecursiveScriptModule._finalize_scriptmodule(out)
+    if optimize:
+        optimize_frozen_module(out)
 
     return out
+
+
+def optimize_frozen_module(mod):
+    r"""
+    Runs a series of optimizations looking for patterns that occur in frozen graphs.
+    The current set of optimizations is:
+        - Conv -> Batchnorm folding
+        - Conv -> Add/Sub folding
+        - Conv -> Mul/Div folding
+
+    Args:
+        mod (:class:`ScriptModule`): a frozen module to be optimized
+
+    Returns:
+        None
+
+    Note:
+        On rare occasions, this can result in slower execution.
+
+    Example (Freezing a module with Conv->Batchnorm):
+
+    .. code-block:: python
+
+        import torch
+        in_channels, out_channels = 3, 32
+        conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
+        bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
+        mod = torch.nn.Sequential(conv, bn)
+        # set optimize to False here; by default freezing runs optimize_frozen_module
+        frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
+        # inspect frozen mod
+        assert "batch_norm" in str(frozen_mod.graph)
+        torch.jit.optimize_frozen_module(frozen_mod)
+        assert "batch_norm" not in str(frozen_mod.graph)
+
+    """
+    torch._C._jit_pass_optimize_frozen_graph(mod.graph)