Add Post Freezing Optimizations, turn on by default in torch.jit.freeze (#50222)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/50222
This PR adds a pass which runs a set of optimizations after freezing. Currently this encompasses Conv-BN folding and Conv->Add/Sub/Mul/Div folding; I'm also planning to add dropout removal.
I would like some feedback on the API. torch.jit.freeze is technically in the prototype phase, so we have some leeway to make changes. In the majority of cases, the user is going to want to freeze their model and then run inference, so I would prefer the optimizations to be opt-out rather than opt-in. All internal/framework use cases of freezing use `freeze_module`, not the Python API, so this shouldn't break anything.
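A minimal sketch of the resulting usage (mirroring the test added below):

```python
import torch
import torch.nn as nn

mod = torch.jit.script(nn.Sequential(
    nn.Conv2d(3, 32, kernel_size=3, stride=2, bias=True),
    nn.BatchNorm2d(32, eps=0.001),
).eval())

# Default (opt-out): freezing runs the post-freezing optimizations,
# so the Conv-BN pattern is folded away.
frozen = torch.jit.freeze(mod)

# Opting out, then running the pass separately via the new API.
frozen_unopt = torch.jit.freeze(mod, optimize=False)
torch.jit.optimize_frozen_module(frozen_unopt)
```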
I have exposed the optimization pass as a separate API to keep things modular, even though I suspect running it separately will be uncommon. In a future PR, I would like to add a `torch::jit::freeze` intended for C++ use that follows the same API as `torch.jit.freeze` and runs the optimizations.
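Until then, the new pass can already be driven from C++ directly; a sketch, assuming `scripted` is an eval-mode `torch::jit::Module` and using the existing `freeze_module` pass rather than the future wrapper:

```cpp
#include <torch/csrc/jit/api/module.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>

// Freeze with the internal pass, then run the new optimizations
// on the frozen forward graph.
torch::jit::Module frozen = torch::jit::freeze_module(scripted);
auto graph = frozen.get_method("forward").graph();
torch::jit::OptimizeFrozenGraph(graph);
```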
Test Plan: Imported from OSS
Reviewed By: tugsbayasgalan
Differential Revision: D25856264
Pulled By: eellison
fbshipit-source-id: 56be1f12cfc459b4c4421d4dfdedff8b9ac77112
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index 7591295..bd31a7e 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1501,3 +1501,19 @@
# add with different dtype
test_conv_fusion(use_bias, nn.Conv2d, False, pytorch_op, False,
add_tensor=torch.rand(1).to(torch.int), expect_success=False)
+
+ def test_optimize_freeze_module(self):
+ in_channels, out_channels = 3, 32
+ conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
+ bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
+ mod = torch.nn.Sequential(conv, bn)
+        # set optimize to False here; by default, freezing runs optimize_frozen_module
+ frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
+ # inspect frozen mod
+ FileCheck().check("batch_norm").run(frozen_mod.graph)
+ torch.jit.optimize_frozen_module(frozen_mod)
+ FileCheck().check_not("batch_norm").run(frozen_mod.graph)
+
+        # by default, freezing runs optimize_frozen_module
+ frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
+ FileCheck().check_not("batch_norm").run(frozen_mod.graph)
diff --git a/tools/build_variables.bzl b/tools/build_variables.bzl
index db1442c..33a38d5 100644
--- a/tools/build_variables.bzl
+++ b/tools/build_variables.bzl
@@ -201,6 +201,7 @@
"torch/csrc/jit/passes/prepack_folding.cpp",
"torch/csrc/jit/passes/fold_conv_bn.cpp",
"torch/csrc/jit/passes/frozen_conv_folding.cpp",
+ "torch/csrc/jit/passes/frozen_graph_optimizations.cpp",
"torch/csrc/jit/passes/remove_expands.cpp",
"torch/csrc/jit/passes/remove_dropout.cpp",
"torch/csrc/jit/passes/requires_grad_analysis.cpp",
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index bbe63f6..bf75dc1 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -170,6 +170,7 @@
preserved_attrs: List[str] = [],
freeze_interfaces: _bool = True,
preserveParameters: _bool = True) -> ScriptModule: ...
+def _jit_pass_optimize_frozen_graph(Graph) -> None: ...
def _is_tracing() -> _bool: ...
def _jit_init() -> _bool: ...
def _jit_flatten(arg: Any) -> Tuple[List[Tensor], IODescriptor]: ...
diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
new file mode 100644
index 0000000..7387025
--- /dev/null
+++ b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
@@ -0,0 +1,21 @@
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#include <torch/csrc/jit/ir/alias_analysis.h>
+#include <torch/csrc/jit/ir/ir_views.h>
+#include <torch/csrc/jit/passes/frozen_conv_folding.h>
+#include <torch/csrc/jit/runtime/graph_executor.h>
+#include <torch/csrc/utils/memory.h>
+
+namespace torch {
+namespace jit {
+
+void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph) {
+  // Run the folding passes a couple of times to capture chains like Conv -> Mul -> Add.
+ for (size_t i = 0; i < 2; i++) {
+ FoldFrozenConvBatchnorm(graph);
+ FoldFrozenConvAddOrSub(graph);
+ FoldFrozenConvMulOrDiv(graph);
+ }
+}
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.h b/torch/csrc/jit/passes/frozen_graph_optimizations.h
new file mode 100644
index 0000000..a808791
--- /dev/null
+++ b/torch/csrc/jit/passes/frozen_graph_optimizations.h
@@ -0,0 +1,19 @@
+#pragma once
+
+#include <torch/csrc/jit/ir/ir.h>
+
+/** \brief Runs a set of optimizations on frozen graphs.
+ *
+ * Currently this set of optimizations is:
+ * - FoldFrozenConvBatchnorm
+ * - FoldFrozenConvAddOrSub
+ * - FoldFrozenConvMulOrDiv
+ */
+
+namespace torch {
+namespace jit {
+
+TORCH_API void OptimizeFrozenGraph(std::shared_ptr<Graph>& graph);
+
+} // namespace jit
+} // namespace torch
diff --git a/torch/csrc/jit/python/init.cpp b/torch/csrc/jit/python/init.cpp
index 75b6831..99a0330 100644
--- a/torch/csrc/jit/python/init.cpp
+++ b/torch/csrc/jit/python/init.cpp
@@ -23,6 +23,7 @@
#include <torch/csrc/jit/passes/fold_conv_bn.h>
#include <torch/csrc/jit/passes/freeze_module.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
+#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/fuse_linear.h>
#include <torch/csrc/jit/passes/fuse_relu.h>
#include <torch/csrc/jit/passes/graph_fuser.h>
@@ -299,6 +300,7 @@
.def("_jit_pass_fold_frozen_conv_bn", &FoldFrozenConvBatchnorm)
.def("_jit_pass_fold_frozen_conv_add_or_sub", &FoldFrozenConvAddOrSub)
.def("_jit_pass_fold_frozen_conv_mul_or_div", &FoldFrozenConvMulOrDiv)
+ .def("_jit_pass_optimize_frozen_graph", &OptimizeFrozenGraph)
.def("_jit_pass_fuse_linear", &FuseLinear)
.def(
"_jit_pass_fuse_add_relu",
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index cfd3271..bf36436 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -45,7 +45,7 @@
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
from torch.jit.cuda import stream
-from torch.jit._freeze import freeze
+from torch.jit._freeze import freeze, optimize_frozen_module
# For backwards compatibility
_fork = fork
@@ -93,20 +93,20 @@
return _script_if_tracing(fn)
-# for torch.jit.isinstance
+# for torch.jit.isinstance
def isinstance(obj, target_type):
"""
- This function provides for conatiner type refinement in TorchScript. It can refine
+    This function provides for container type refinement in TorchScript. It can refine
parameterized containers of the List, Dict, Tuple, and Optional types. E.g. ``List[str]``,
- ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also
+ ``Dict[str, List[torch.Tensor]]``, ``Optional[Tuple[int,str,int]]``. It can also
refine basic types such as bools and ints that are available in TorchScript.
Args:
obj: object to refine the type of
- target_type: type to try to refine obj to
+ target_type: type to try to refine obj to
Returns:
- ``bool``: True if obj was successfully refined to the type of target_type,
- False otherwise with no new type refinement
+ ``bool``: True if obj was successfully refined to the type of target_type,
+ False otherwise with no new type refinement
Example (using ``torch.jit.isinstance`` for type refinement):
diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py
index 5db57f0..98e53c6 100644
--- a/torch/jit/_freeze.py
+++ b/torch/jit/_freeze.py
@@ -10,7 +10,7 @@
from torch.jit._script import RecursiveScriptModule, ScriptModule
-def freeze(mod, preserved_attrs: Optional[List[str]] = None):
+def freeze(mod, preserved_attrs: Optional[List[str]] = None, optimize: bool = True):
r"""
Freezing a :class:`ScriptModule` will clone it and attempt to inline the cloned
module's submodules, parameters, and attributes as constants in the TorchScript IR Graph.
@@ -26,6 +26,11 @@
preserved_attrs (Optional[List[str]]): a list of attributes to preserve in addition to the forward method.
Attributes modified in preserved methods will also be preserved.
+ optimize (bool): If ``True``, a set of optimization passes will be run to prepare the graph for inference,
+ in addition to the graph cleanup that already occurs. The details of the optimizations can be found in
+        ``torch.jit.optimize_frozen_module``.
+
+
Returns:
Frozen :class:`ScriptModule`.
@@ -97,5 +102,42 @@
out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
RecursiveScriptModule._finalize_scriptmodule(out)
+ if optimize:
+ optimize_frozen_module(out)
return out
+
+
+def optimize_frozen_module(mod):
+ r"""
+ Runs a series of optimizations looking for patterns that occur in frozen graphs.
+ The current set of optimizations is:
+ - Conv -> Batchnorm folding
+ - Conv -> Add/Sub folding
+ - Conv -> Mul/Div folding
+
+ Args:
+ mod (:class:`ScriptModule`): a frozen module to be optimized
+
+ Returns:
+ None
+
+ Note:
+        On rare occasions, this can result in slower execution.
+
+ Example (Freezing a module with Conv->Batchnorm)
+ .. code-block:: python
+ import torch
+ in_channels, out_channels = 3, 32
+ conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
+ bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
+ mod = torch.nn.Sequential(conv, bn)
+        # set optimize to False here; by default, freezing runs optimize_frozen_module
+ frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
+ # inspect frozen mod
+ assert "batch_norm" in str(frozen_mod.graph)
+ torch.jit.optimize_frozen_module(frozen_mod)
+ assert "batch_norm" not in str(frozen_mod.graph)
+
+ """
+ torch._C._jit_pass_optimize_frozen_graph(mod.graph)