[JIT] Add optimize_for_inference API (#58193)
Summary:
Freezing is an existing pass which partially evaluates your model and applies generic optimizations that should speed it up. `optimize_for_inference` is a counterpart which runs build- and server-specific optimizations on top of those. The interaction with the existing `optimize_frozen_module` is not great; I guess we could just deprecate that API entirely? It was never officially released and existed only to document the `optimize_numerics` keyword.
Eventually I would like to add a way of supplying example inputs, but I didn't add that here because they are not being used at all yet. I also have not yet included a way to blacklist individual optimizations; I would like to wait until we move this to Beta and have a little more clarity on how everything will fit together. I also think blacklisting will be an uncommon use case for the current optimizations.
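A rough sketch of the intended Python usage, mirroring the new test added in test_freezing.py (the `to_mkldnn` conversion only appears on MKLDNN-enabled builds):

```python
import torch
import torch.nn as nn

mod = nn.Linear(20, 30).eval()
scripted = torch.jit.script(mod)

# optimize_for_inference freezes the module if it has not been frozen yet,
# then applies build-specific passes such as MKLDNN conversion and
# conv-add-relu fusion
optimized = torch.jit.optimize_for_inference(scripted)

# an already-frozen module can be passed as well
frozen = torch.jit.freeze(torch.jit.script(mod))
optimized_from_frozen = torch.jit.optimize_for_inference(frozen)

# the optimized module runs like any ScriptModule; on an MKLDNN-enabled
# build its graph will contain `to_mkldnn` conversions
x = torch.rand(4, 20)
out = optimized(x)
print("to_mkldnn" in str(optimized.graph), out.shape)
```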
Pull Request resolved: https://github.com/pytorch/pytorch/pull/58193
Reviewed By: bertmaher, navahgar
Differential Revision: D28443714
Pulled By: eellison
fbshipit-source-id: b032355bb2585720a6d2f00c89d0d9a7ef60e649
diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp
index 3c670e4..9fd692b 100644
--- a/test/cpp/jit/test_module_api.cpp
+++ b/test/cpp/jit/test_module_api.cpp
@@ -365,7 +365,10 @@
auto frozen_mod = torch::jit::freeze(m);
auto forward_g = frozen_mod.get_method("forward").graph();
testing::FileCheck().check_not("GetAttr")->run(*forward_g);
- ;
+
+ auto frozen_mod2 = torch::jit::optimize_for_inference(m);
+ forward_g = frozen_mod2.get_method("forward").graph();
+ testing::FileCheck().check_not("GetAttr")->run(*forward_g);
}
// NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index d2d0c14..e45d144 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1590,14 +1590,14 @@
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
- # set optimize to False here, by default freezing runs optimize_frozen_module
+ # set optimize_numerics to False here; by default, freezing runs run_frozen_optimizations
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
# inspect frozen mod
FileCheck().check("batch_norm").run(frozen_mod.graph)
- torch.jit.optimize_frozen_module(frozen_mod)
+ torch.jit.run_frozen_optimizations(frozen_mod)
FileCheck().check_not("batch_norm").run(frozen_mod.graph)
- # optimize_frozen_module should be run
+ # run_frozen_optimizations should be run
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
FileCheck().check_not("batch_norm").run(frozen_mod.graph)
@@ -1849,7 +1849,7 @@
else:
scripted_mod = torch.jit.script(mod_eager)
- frozen_mod = torch.jit.freeze(scripted_mod)
+ frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
if add_z:
FileCheck().check("aten::cudnn_convolution_add_relu").run(frozen_mod.graph)
else:
@@ -1993,6 +1993,19 @@
# and we aren't testing aten impls anyways
self.assertTrue(torch.allclose(aten_op(x, inplace=False), m(x).to_dense()))
+ @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
+ def test_optimize_for_inference(self):
+ with set_default_dtype(torch.float):
+ mod = nn.Linear(20, 30).eval()
+ scripted_mod = torch.jit.script(mod)
+
+ optimized = torch.jit.optimize_for_inference(scripted_mod)
+ FileCheck().check("to_mkldnn").run(optimized.graph)
+
+ frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
+ optimized = torch.jit.optimize_for_inference(frozen_mod)
+ FileCheck().check("to_mkldnn").run(optimized.graph)
+
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
class TestMKLDNNReinplacing(JitTestCase):
def setUp(self):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index b53ab58..233bbe5 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -182,6 +182,7 @@
def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
+def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...
def _is_tracing() -> _bool: ...
diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp
index a5ab172..a2a55d1 100644
--- a/torch/csrc/jit/api/module.cpp
+++ b/torch/csrc/jit/api/module.cpp
@@ -10,7 +10,9 @@
#include <torch/csrc/jit/jit_log.h>
#include <torch/csrc/jit/passes/dead_code_elimination.h>
#include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
#include <torch/csrc/jit/passes/inliner.h>
#include <torch/csrc/jit/runtime/operator.h>
@@ -484,6 +486,18 @@
return out_mod;
}
+Module optimize_for_inference(Module& module) {
+ // if the module still has a `training` attribute, it has not been frozen yet
+ Module frozen_mod = module;
+ if (module._ivalue()->type()->hasAttribute("training")) {
+ frozen_mod = freeze(module, {}, true);
+ }
+
+ auto graph = frozen_mod.get_method("forward").graph();
+ FuseFrozenConvAddRelu(graph);
+ ConvertFrozenOpsToMKLDNN(graph);
+ return frozen_mod;
+}
+
buffer_list Module::buffers(bool recurse) const {
return buffer_list(*this, recurse, /*return_module=*/false);
}
diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h
index ecb8468..4940898 100644
--- a/torch/csrc/jit/api/module.h
+++ b/torch/csrc/jit/api/module.h
@@ -295,6 +295,10 @@
c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
bool optimize_numerics = true);
+// C++ equivalent of the Python API `torch.jit.optimize_for_inference`. See its
+// documentation for details.
+TORCH_API Module optimize_for_inference(Module& module);
+
namespace detail {
struct TORCH_API SlotCursor {
diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
index b228f4e..ca1af11 100644
--- a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
+++ b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
@@ -1,6 +1,5 @@
#include <torch/csrc/jit/ir/alias_analysis.h>
#include <torch/csrc/jit/ir/ir_views.h>
-#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
#include <torch/csrc/jit/passes/frozen_conv_folding.h>
#include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
#include <torch/csrc/jit/passes/remove_dropout.h>
@@ -22,7 +21,6 @@
FoldFrozenConvMulOrDiv(graph);
}
}
- FuseFrozenConvAddRelu(graph);
}
} // namespace jit
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 6018f4a..bb5a45e 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -49,7 +49,7 @@
from torch.jit._serialization import save, load
from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
-from torch.jit._freeze import freeze, optimize_frozen_module
+from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
# For backwards compatibility
_fork = fork
diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py
index a193d9c..cab6d3c 100644
--- a/torch/jit/_freeze.py
+++ b/torch/jit/_freeze.py
@@ -20,6 +20,10 @@
Freezing currently only accepts ScriptModules that are in eval mode.
+ Freezing applies generic optimizations that will speed up your model regardless of machine.
+ To further optimize using server-specific settings, run `optimize_for_inference` after
+ freezing.
+
Args:
mod (:class:`ScriptModule`): a module to be frozen
@@ -27,7 +31,7 @@
Attributes modified in preserved methods will also be preserved.
optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
- preserve numerics. Full details of optimization can be found at `torch.jit.optimize_frozen_module`.
+ preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`.
Returns:
Frozen :class:`ScriptModule`.
@@ -83,6 +87,12 @@
If you're not sure why an attribute is not being inlined as a constant, you can run
`dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
attribute is being modified.
+
+ Note:
+ Because freezing makes weights constant and removes the module hierarchy, `to` and other
+ nn.Module methods that manipulate device or dtype no longer work. As a workaround,
+ you can remap devices by specifying `map_location` in `torch.jit.load`; however,
+ device-specific logic may have been baked into the model.
"""
if not isinstance(mod, ScriptModule):
raise RuntimeError(
@@ -100,12 +110,11 @@
out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
RecursiveScriptModule._finalize_scriptmodule(out)
- optimize_frozen_module(out, optimize_numerics)
+ run_frozen_optimizations(out, optimize_numerics)
return out
-
-def optimize_frozen_module(mod, optimize_numerics: bool = True):
+def run_frozen_optimizations(mod, optimize_numerics: bool = True):
r"""
Runs a series of optimizations looking for patterns that occur in frozen graphs.
The current set of optimizations is:
@@ -136,11 +145,11 @@
conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
mod = torch.nn.Sequential(conv, bn)
- # set optimize to False here, by default freezing runs optimize_frozen_module
+ # set optimize to False here, by default freezing runs run_frozen_optimizations
frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
# inspect frozen mod
assert "batch_norm" in str(frozen_mod.graph)
- torch.jit.optimize_frozen_module(frozen_mod)
+ torch.jit.run_frozen_optimizations(frozen_mod)
assert "batch_norm" not in str(frozen_mod.graph)
"""
@@ -153,4 +162,32 @@
torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
+
+def optimize_for_inference(mod: ScriptModule) -> ScriptModule:
+ """
+ Performs a set of optimization passes to optimize a model for the
+ purposes of inference. If the model is not already frozen, optimize_for_inference
+ will invoke `torch.jit.freeze` automatically.
+
+ In addition to generic optimizations that should speed up your model regardless
+ of environment, `optimize_for_inference` will also bake in build-specific settings
+ such as the presence of CUDNN or MKLDNN, and may in the future make transformations
+ which speed things up on one machine but slow things down on another. Accordingly,
+ serialization is not supported after invoking `optimize_for_inference` and is not
+ guaranteed to work.
+
+ This API is still a prototype and may have the potential to slow down your model.
+ The primary use cases targeted so far have been vision models on CPU and, to a
+ lesser extent, GPU.
+ """
+ if not isinstance(mod, ScriptModule):
+ raise RuntimeError(
+ "optimize_for_inference expects a ScriptModule as input. "
+ "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
+
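+ # if the module still has a `training` attribute, it has not been frozen yet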
+ if hasattr(mod, "training"):
+ mod = freeze(mod.eval())
+
+ torch._C._jit_pass_convert_frozen_ops_to_mkldnn(mod.graph)
torch._C._jit_pass_fuse_frozen_conv_add_relu(mod.graph)
+ return mod