[JIT] Add optimize_for_inference API (#58193)

Summary:
Freezing is an existing pass that partially evaluates your model and applies generic optimizations that should speed it up. `optimize_for_inference` is a counterpart that additionally runs build- and server-specific optimizations. The interaction with the existing `optimize_frozen_module` is not great; we could probably just deprecate that API entirely, since it was never officially released and only existed to document the `optimize_numerics` keyword.
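
As a rough sketch of the intended usage (the tiny `nn.Linear` module below is just a placeholder):

```python
import torch

mod = torch.jit.script(torch.nn.Linear(20, 30).eval())

# freeze() applies generic optimizations that help regardless of machine
frozen_mod = torch.jit.freeze(mod)

# optimize_for_inference() freezes (if necessary) and additionally applies
# build/server specific optimizations, e.g. MKLDNN/CUDNN rewrites
optimized_mod = torch.jit.optimize_for_inference(mod)
```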

Eventually I would like to add a way of passing example inputs, but I didn't add that here because they would not be used at all yet. I also have not included a way to blacklist individual optimizations; I'd like to wait until we move this to Beta and have a little more clarity on how everything fits together. I also expect blacklisting to be an uncommon use case for the current optimizations.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/58193

Reviewed By: bertmaher, navahgar

Differential Revision: D28443714

Pulled By: eellison

fbshipit-source-id: b032355bb2585720a6d2f00c89d0d9a7ef60e649
diff --git a/test/cpp/jit/test_module_api.cpp b/test/cpp/jit/test_module_api.cpp
index 3c670e4..9fd692b 100644
--- a/test/cpp/jit/test_module_api.cpp
+++ b/test/cpp/jit/test_module_api.cpp
@@ -365,7 +365,10 @@
   auto frozen_mod = torch::jit::freeze(m);
   auto forward_g = frozen_mod.get_method("forward").graph();
   testing::FileCheck().check_not("GetAttr")->run(*forward_g);
-  ;
+
+  auto frozen_mod2 = torch::jit::optimize_for_inference(m);
+  forward_g = frozen_mod2.get_method("forward").graph();
+  testing::FileCheck().check_not("GetAttr")->run(*forward_g);
 }
 
 // NOLINTNEXTLINE(cppcoreguidelines-avoid-non-const-global-variables)
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index d2d0c14..e45d144 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1590,14 +1590,14 @@
         conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
         bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
         mod = torch.nn.Sequential(conv, bn)
-        # set optimize to False here, by default freezing runs optimize_frozen_module
+        # set optimize to False here, by default freezing runs run_frozen_optimizations
         frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize_numerics=False)
         # inspect frozen mod
         FileCheck().check("batch_norm").run(frozen_mod.graph)
-        torch.jit.optimize_frozen_module(frozen_mod)
+        torch.jit.run_frozen_optimizations(frozen_mod)
         FileCheck().check_not("batch_norm").run(frozen_mod.graph)
 
-        # optimize_frozen_module should be run
+        # run_frozen_optimizations should be run
         frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
         FileCheck().check_not("batch_norm").run(frozen_mod.graph)
 
@@ -1849,7 +1849,7 @@
             else:
                 scripted_mod = torch.jit.script(mod_eager)
 
-            frozen_mod = torch.jit.freeze(scripted_mod)
+            frozen_mod = torch.jit.optimize_for_inference(scripted_mod)
             if add_z:
                 FileCheck().check("aten::cudnn_convolution_add_relu").run(frozen_mod.graph)
             else:
@@ -1993,6 +1993,19 @@
                         # and we aren't testing aten impls anyways
                         self.assertTrue(torch.allclose(aten_op(x, inplace=False), m(x).to_dense()))
 
+    @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
+    def test_optimize_for_inference(self):
+        with set_default_dtype(torch.float):
+            mod = nn.Linear(20, 30).eval()
+            scripted_mod = torch.jit.script(mod)
+
+            optimized = torch.jit.optimize_for_inference(scripted_mod)
+            FileCheck().check("to_mkldnn").run(optimized.graph)
+
+            frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()))
+            optimized = torch.jit.optimize_for_inference(frozen_mod)
+            FileCheck().check("to_mkldnn").run(optimized.graph)
+
 @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
 class TestMKLDNNReinplacing(JitTestCase):
     def setUp(self):
diff --git a/torch/_C/__init__.pyi.in b/torch/_C/__init__.pyi.in
index b53ab58..233bbe5 100644
--- a/torch/_C/__init__.pyi.in
+++ b/torch/_C/__init__.pyi.in
@@ -182,6 +182,7 @@
 def _jit_pass_fold_frozen_conv_add_or_sub(graph: Graph): ...
 def _jit_pass_fold_frozen_conv_mul_or_div(graph: Graph): ...
 def _jit_pass_fuse_frozen_conv_add_relu(graph: Graph): ...
+def _jit_pass_convert_frozen_ops_to_mkldnn(graph: Graph): ...
 def _jit_pass_remove_dropout(module: 'torch.jit.ScriptModule'): ...
 
 def _is_tracing() -> _bool: ...
diff --git a/torch/csrc/jit/api/module.cpp b/torch/csrc/jit/api/module.cpp
index a5ab172..a2a55d1 100644
--- a/torch/csrc/jit/api/module.cpp
+++ b/torch/csrc/jit/api/module.cpp
@@ -10,7 +10,9 @@
 #include <torch/csrc/jit/jit_log.h>
 #include <torch/csrc/jit/passes/dead_code_elimination.h>
 #include <torch/csrc/jit/passes/freeze_module.h>
+#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
+#include <torch/csrc/jit/passes/frozen_ops_to_mkldnn.h>
 #include <torch/csrc/jit/passes/inliner.h>
 #include <torch/csrc/jit/runtime/operator.h>
 
@@ -484,6 +486,18 @@
   return out_mod;
 }
 
+Module optimize_for_inference(Module& module) {
+  // If the module still has a "training" attribute it has not been frozen
+  // yet, so freeze it first; otherwise optimize the module as-is.
+  Module frozen_mod = module;
+  if (module._ivalue()->type()->hasAttribute("training")) {
+    frozen_mod = freeze(module, {}, true);
+  }
+
+  auto graph = frozen_mod.get_method("forward").graph();
+  FuseFrozenConvAddRelu(graph);
+  ConvertFrozenOpsToMKLDNN(graph);
+  return frozen_mod;
+}
+
 buffer_list Module::buffers(bool recurse) const {
   return buffer_list(*this, recurse, /*return_module=*/false);
 }
diff --git a/torch/csrc/jit/api/module.h b/torch/csrc/jit/api/module.h
index ecb8468..4940898 100644
--- a/torch/csrc/jit/api/module.h
+++ b/torch/csrc/jit/api/module.h
@@ -295,6 +295,10 @@
     c10::optional<std::vector<std::string>> preserved_attrs = c10::nullopt,
     bool optimize_numerics = true);
 
+// C++ equivalent of the `torch.jit.optimize_for_inference` API. See its
+// documentation for details.
+TORCH_API Module optimize_for_inference(Module& module);
+
 namespace detail {
 
 struct TORCH_API SlotCursor {
diff --git a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
index b228f4e..ca1af11 100644
--- a/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
+++ b/torch/csrc/jit/passes/frozen_graph_optimizations.cpp
@@ -1,6 +1,5 @@
 #include <torch/csrc/jit/ir/alias_analysis.h>
 #include <torch/csrc/jit/ir/ir_views.h>
-#include <torch/csrc/jit/passes/frozen_conv_add_relu_fusion.h>
 #include <torch/csrc/jit/passes/frozen_conv_folding.h>
 #include <torch/csrc/jit/passes/frozen_graph_optimizations.h>
 #include <torch/csrc/jit/passes/remove_dropout.h>
@@ -22,7 +21,6 @@
       FoldFrozenConvMulOrDiv(graph);
     }
   }
-  FuseFrozenConvAddRelu(graph);
 }
 
 } // namespace jit
diff --git a/torch/jit/__init__.py b/torch/jit/__init__.py
index 6018f4a..bb5a45e 100644
--- a/torch/jit/__init__.py
+++ b/torch/jit/__init__.py
@@ -49,7 +49,7 @@
 from torch.jit._serialization import save, load
 from torch.jit._fuser import optimized_execution, fuser, last_executed_optimized_graph
 
-from torch.jit._freeze import freeze, optimize_frozen_module
+from torch.jit._freeze import freeze, optimize_for_inference, run_frozen_optimizations
 
 # For backwards compatibility
 _fork = fork
diff --git a/torch/jit/_freeze.py b/torch/jit/_freeze.py
index a193d9c..cab6d3c 100644
--- a/torch/jit/_freeze.py
+++ b/torch/jit/_freeze.py
@@ -20,6 +20,10 @@
 
     Freezing currently only accepts ScriptModules that are in eval mode.
 
+    Freezing applies generic optimizations that will speed up your model regardless of machine.
+    To further optimize using server-specific settings, run `optimize_for_inference` after
+    freezing.
+
     Args:
         mod (:class:`ScriptModule`): a module to be frozen
 
@@ -27,7 +31,7 @@
         Attributes modified in preserved methods will also be preserved.
 
         optimize_numerics (bool): If ``True``, a set of optimization passes will be run that does not strictly
-        preserve numerics. Full details of optimization can be found at `torch.jit.optimize_frozen_module`.
+        preserve numerics. Full details of optimization can be found at `torch.jit.run_frozen_optimizations`.
 
     Returns:
         Frozen :class:`ScriptModule`.
@@ -83,6 +87,12 @@
         If you're not sure why an attribute is not being inlined as a constant, you can run
         `dump_alias_db` on frozen_module.forward.graph to see if freezing has detected the
         attribute is being modified.
+
+    Note:
+        Because freezing makes weights constant and removes the module hierarchy, `to` and other
+        nn.Module methods that manipulate device or dtype no longer work. As a workaround,
+        you can remap devices by specifying `map_location` in `torch.jit.load`; however,
+        device-specific logic may have been baked into the model.
     """
     if not isinstance(mod, ScriptModule):
         raise RuntimeError(
@@ -100,12 +110,11 @@
 
     out = RecursiveScriptModule(torch._C._freeze_module(mod._c, preserved_attrs))
     RecursiveScriptModule._finalize_scriptmodule(out)
-    optimize_frozen_module(out, optimize_numerics)
+    run_frozen_optimizations(out, optimize_numerics)
 
     return out
 
-
-def optimize_frozen_module(mod, optimize_numerics: bool = True):
+def run_frozen_optimizations(mod, optimize_numerics: bool = True):
     r"""
     Runs a series of optimizations looking for patterns that occur in frozen graphs.
     The current set of optimizations is:
@@ -136,11 +145,11 @@
         conv = torch.nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=2, bias=True)
         bn = torch.nn.BatchNorm2d(out_channels, eps=.001)
         mod = torch.nn.Sequential(conv, bn)
-        # set optimize to False here, by default freezing runs optimize_frozen_module
+        # set optimize to False here, by default freezing runs run_frozen_optimizations
         frozen_mod = torch.jit.freeze(torch.jit.script(mod.eval()), optimize=False)
         # inspect frozen mod
         assert "batch_norm" in str(frozen_mod.graph)
-        torch.jit.optimize_frozen_module(frozen_mod)
+        torch.jit.run_frozen_optimizations(frozen_mod)
         assert "batch_norm" not in str(frozen_mod.graph)
 
     """
@@ -153,4 +162,32 @@
             torch._C._jit_pass_fold_frozen_conv_bn(mod.graph)
             torch._C._jit_pass_fold_frozen_conv_add_or_sub(mod.graph)
             torch._C._jit_pass_fold_frozen_conv_mul_or_div(mod.graph)
+
+def optimize_for_inference(mod: ScriptModule) -> ScriptModule:
+    """
+    Performs a set of optimization passes to optimize a model for the
+    purposes of inference. If the model is not already frozen, optimize_for_inference
+    will invoke `torch.jit.freeze` automatically.
+
+    In addition to generic optimizations that should speed up your model regardless
+    of environment, `optimize_for_inference` will also bake in build-specific settings
+    such as the presence of CUDNN or MKLDNN, and may in the future make transformations
+    which speed things up on one machine but slow things down on another. Accordingly,
+    serialization after invoking `optimize_for_inference` is not implemented and is
+    not guaranteed.
+
+    This API is still a prototype and in some cases may slow down your model. The
+    primary use cases targeted so far are vision models on CPU and, to a lesser
+    extent, on GPU.
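+
+    Example (an illustrative sketch; the ``nn.Linear`` module is just a placeholder)::
+
+        import torch
+
+        mod = torch.jit.script(torch.nn.Linear(20, 30).eval())
+        opt_mod = torch.jit.optimize_for_inference(mod)
+        out = opt_mod(torch.rand(16, 20))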
+    """
+    if not isinstance(mod, ScriptModule):
+        raise RuntimeError(
+            "optimize_for_inference expects a ScriptModule as input. "
+            "Please use torch.jit.script or torch.jit.trace to script your 'nn.Module'.")
+
+    if hasattr(mod, "training"):
+        mod = freeze(mod.eval())
+
+    torch._C._jit_pass_convert_frozen_ops_to_mkldnn(mod.graph)
     torch._C._jit_pass_fuse_frozen_conv_add_relu(mod.graph)
+    return mod