add all pools, BatchNorm and Tanh (i.e. all MKLDNN ops with ideep bindings) to MKLDNNFuser (#56541)

Summary:
Fixes #{issue number}
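
A minimal usage sketch (illustrative only, not part of the test changes below): freeze a scripted eval-mode model that contains one of the newly handled ops and run the frozen-ops-to-MKLDNN conversion on its graph. This assumes an MKL-DNN build; torch._C._jit_pass_convert_frozen_ops_to_mkldnn is the internal binding that self.run_pass("convert_frozen_ops_to_mkldnn", ...) resolves to in the tests.

    import torch

    if torch._C.has_mkldnn:
        # conv followed by two of the newly covered ops (avg_pool2d, tanh)
        model = torch.nn.Sequential(
            torch.nn.Conv2d(3, 64, 2, 2),
            torch.nn.ReLU(),
            torch.nn.AvgPool2d(4),
            torch.nn.Tanh(),
        ).eval()
        frozen = torch.jit.freeze(torch.jit.script(model))
        torch._C._jit_pass_convert_frozen_ops_to_mkldnn(frozen.graph)
        inp = torch.randn(10, 3, 224, 224)
        # the MKLDNN-converted frozen module should still match eager results
        assert torch.allclose(model(inp), frozen(inp))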

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56541

Reviewed By: pbelevich

Differential Revision: D27930353

Pulled By: Krovatkin

fbshipit-source-id: 4d5b932bad4154e8bdd6e35498354e13b39c87a1
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 69a2239..aa1d921 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -677,6 +677,7 @@
 _(aten, take_along_dim) \
 _(aten, tan) \
 _(aten, tanh) \
+_(aten, tanh_) \
 _(aten, tensor) \
 _(aten, tensordot) \
 _(aten, tensor_split) \
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index 846e599..b7f956b 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -37,6 +37,10 @@
     torch.ones(1).cuda()  # initialize cuda context
     TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=torch.device('cuda:0')))
 
+def removeExceptions(graph):
+    for n in graph.findAllNodes('prim::RaiseException'):
+        n.destroy()
+
 class TestFreezing(JitTestCase):
     def test_freeze_module(self):
         class M(nn.Module):
@@ -1850,16 +1854,52 @@
             self.assertTrue(torch.allclose(model(inp), mod(inp)))
 
     @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
-    def test_adaptive_avgpool2d(self):
+    def test_pool2d_batchnorm(self):
         with set_default_dtype(torch.float):
 
-            sub_model = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 2, 2), torch.nn.AdaptiveAvgPool2d(4), torch.nn.Hardswish())
-            sub_model.eval()
-            mod = torch.jit.freeze(torch.jit.script(sub_model))
-            N, C, H, W, = 10, 3, 224, 224
-            inp = torch.randn(N, C, H, W)
-            self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
-            self.assertTrue(torch.allclose(sub_model(inp), mod(inp)))
+            pooling_layers = [torch.nn.AdaptiveAvgPool2d(4),
+                              # torch.nn.AdaptiveMaxPool2d(4), # returns a tuple
+                              torch.nn.MaxPool2d(4),
+                              torch.nn.AvgPool2d(4),
+                              torch.nn.BatchNorm2d(64).eval()]
+
+            for pl in pooling_layers:
+                sub_model = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 2, 2), torch.nn.ReLU(), pl, torch.nn.Hardswish())
+                sub_model.eval()
+                mod = torch.jit.freeze(torch.jit.script(sub_model))
+                N, C, H, W, = 10, 3, 224, 224
+                inp = torch.randn(N, C, H, W)
+                # these two passes are needed to remove
+                # a size check in BatchNorm2d
+                removeExceptions(mod.graph)
+                self.run_pass('dce', mod.graph)
+                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
+                FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
+                self.assertTrue(torch.allclose(sub_model(inp), mod(inp)))
+
+    @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
+    def test_pool3d_batchnorm(self):
+        with set_default_dtype(torch.float):
+
+            pooling_layers = [torch.nn.MaxPool3d(4),
+                              # torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings
+                              # torch.nn.AdaptiveMaxPool3d(4), # returns a tuple
+                              torch.nn.AvgPool3d(4),
+                              torch.nn.BatchNorm3d(64).eval()]
+
+            for pl in pooling_layers:
+                sub_model = torch.nn.Sequential(torch.nn.Conv3d(3, 64, 2, 2), torch.nn.ReLU(), pl, torch.nn.Hardswish())
+                sub_model.eval()
+                mod = torch.jit.freeze(torch.jit.script(sub_model))
+                N, C, H, W, D = 10, 3, 64, 64, 64
+                inp = torch.randn(N, C, D, H, W)
+                # these two passes are needed to remove
+                # a size check in BatchNorm3d
+                removeExceptions(mod.graph)
+                self.run_pass('dce', mod.graph)
+                self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
+                FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
+                self.assertTrue(torch.allclose(sub_model(inp), mod(inp)))
 
     @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
     @skipIfNoTorchVision
@@ -1869,6 +1909,7 @@
                 torch.nn.Hardswish(),
                 torch.nn.Hardsigmoid(),
                 torch.nn.ReLU6(),
+                torch.nn.Tanh(),
                 torch.nn.Hardtanh(0., 6.),
                 torch.nn.Hardtanh(1., 100.),
                 torch.nn.Hardtanh(-100., -1.),
@@ -1882,6 +1923,7 @@
                 N, C, H, W, = 10, 3, 224, 224
                 inp = torch.randn(N, C, H, W)
                 self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
+                FileCheck().check_count("aten::to_dense", 1, exactly=True).run(mod.graph)
                 self.assertTrue(torch.allclose(sub_model(inp), mod(inp)))
 
     @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
index 5ab1eb9..779093a 100644
--- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
+++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
@@ -181,7 +181,7 @@
     auto k = node->kind();
     if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
         k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
-        k == prim::MKLDNNHardTanh) {
+        k == prim::MKLDNNHardTanh || k == aten::tanh) {
       if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
         continue;
       }
@@ -844,11 +844,19 @@
       case aten::sigmoid:
       case aten::hardsigmoid:
       case aten::hardswish:
+      case aten::tanh:
+      case aten::batch_norm:
       // TODO: max_pool on mkldnn can be slower than in eager. ideally, we'd
       // only fuse it if we knew including max_pool lead to fewer layout
       // conversions. from initial testing including it speeds up models
       case aten::max_pool2d:
       case aten::max_pool3d:
+      case aten::avg_pool2d:
+      case aten::adaptive_avg_pool2d:
+      case aten::avg_pool3d:
+        // case aten::adaptive_max_pool2d: // returns a tuple, which breaks fusion
+        // case aten::adaptive_max_pool3d: // returns a tuple, which breaks fusion
+        // case aten::adaptive_avg_pool3d: // no ideep binding
         return true;
     }
 
@@ -992,6 +1000,7 @@
           aten::sigmoid_,
           aten::hardsigmoid_,
           aten::hardtanh_,
+          aten::tanh_,
       };
       return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
     });