add all pools, BatchNorm and Tanh (i.e. all ideep-backed MKLDNN ops) to the MKLDNN fuser (#56541)
Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56541
Reviewed By: pbelevich
Differential Revision: D27930353
Pulled By: Krovatkin
fbshipit-source-id: 4d5b932bad4154e8bdd6e35498354e13b39c87a1
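
For context, here is a minimal sketch of how the newly covered ops are exercised, mirroring the tests added below. It assumes an MKL-DNN build, and that the `convert_frozen_ops_to_mkldnn` pass invoked through `JitTestCase.run_pass` resolves to the `torch._C._jit_pass_convert_frozen_ops_to_mkldnn` binding; the module and shapes are illustrative, not part of this diff.

```python
import torch
from torch.testing import FileCheck

assert torch._C.has_mkldnn, "requires an MKL-DNN build"

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 64, 2, 2),
    torch.nn.ReLU(),
    torch.nn.AvgPool2d(4),   # pooling ops such as avg_pool2d are now convertible
    torch.nn.Tanh(),         # aten::tanh is now accepted by the fuser
).eval()

frozen = torch.jit.freeze(torch.jit.script(model))
torch._C._jit_pass_convert_frozen_ops_to_mkldnn(frozen.graph)

# As in the tests below, a single conversion back to a dense tensor should
# remain before the return, i.e. the whole chain stays in MKLDNN layout.
FileCheck().check_count("aten::to_dense", 1, exactly=True).run(frozen.graph)

inp = torch.randn(10, 3, 224, 224)
assert torch.allclose(model(inp), frozen(inp))
```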
diff --git a/aten/src/ATen/core/aten_interned_strings.h b/aten/src/ATen/core/aten_interned_strings.h
index 69a2239..aa1d921 100644
--- a/aten/src/ATen/core/aten_interned_strings.h
+++ b/aten/src/ATen/core/aten_interned_strings.h
@@ -677,6 +677,7 @@
_(aten, take_along_dim) \
_(aten, tan) \
_(aten, tanh) \
+_(aten, tanh_) \
_(aten, tensor) \
_(aten, tensordot) \
_(aten, tensor_split) \
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index 846e599..b7f956b 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -37,6 +37,10 @@
torch.ones(1).cuda() # initialize cuda context
TEST_CUDNN = TEST_CUDA and torch.backends.cudnn.is_acceptable(torch.tensor(1., device=torch.device('cuda:0')))
+def removeExceptions(graph):
+ for n in graph.findAllNodes('prim::RaiseException'):
+ n.destroy()
+
class TestFreezing(JitTestCase):
def test_freeze_module(self):
class M(nn.Module):
@@ -1850,16 +1854,52 @@
self.assertTrue(torch.allclose(model(inp), mod(inp)))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
- def test_adaptive_avgpool2d(self):
+ def test_pool2d_batchnorm(self):
with set_default_dtype(torch.float):
- sub_model = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 2, 2), torch.nn.AdaptiveAvgPool2d(4), torch.nn.Hardswish())
- sub_model.eval()
- mod = torch.jit.freeze(torch.jit.script(sub_model))
- N, C, H, W, = 10, 3, 224, 224
- inp = torch.randn(N, C, H, W)
- self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
- self.assertTrue(torch.allclose(sub_model(inp), mod(inp)))
+ pooling_layers = [torch.nn.AdaptiveAvgPool2d(4),
+ # torch.nn.AdaptiveMaxPool2d(4), # returns a tuple
+ torch.nn.MaxPool2d(4),
+ torch.nn.AvgPool2d(4),
+ torch.nn.BatchNorm2d(64).eval()]
+
+ for pl in pooling_layers:
+ sub_model = torch.nn.Sequential(torch.nn.Conv2d(3, 64, 2, 2), torch.nn.ReLU(), pl, torch.nn.Hardswish())
+ sub_model.eval()
+ mod = torch.jit.freeze(torch.jit.script(sub_model))
+ N, C, H, W, = 10, 3, 224, 224
+ inp = torch.randn(N, C, H, W)
+ # these two passes are needed to remove
+ # a size check in BatchNorm2d
+ removeExceptions(mod.graph)
+ self.run_pass('dce', mod.graph)
+ self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
+ FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
+ self.assertTrue(torch.allclose(sub_model(inp), mod(inp)))
+
+ @unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
+ def test_pool3d_batchnorm(self):
+ with set_default_dtype(torch.float):
+
+ pooling_layers = [torch.nn.MaxPool3d(4),
+ # torch.nn.AdaptiveAvgPool3d(4), # no ideep bindings
+ # torch.nn.AdaptiveMaxPool3d(4), # returns a tuple
+ torch.nn.AvgPool3d(4),
+ torch.nn.BatchNorm3d(64).eval()]
+
+ for pl in pooling_layers:
+ sub_model = torch.nn.Sequential(torch.nn.Conv3d(3, 64, 2, 2), torch.nn.ReLU(), pl, torch.nn.Hardswish())
+ sub_model.eval()
+ mod = torch.jit.freeze(torch.jit.script(sub_model))
+ N, C, D, H, W = 10, 3, 64, 64, 64
+ inp = torch.randn(N, C, D, H, W)
+ # these two passes are needed to remove
+ # a size check in BatchNorm3d
+ removeExceptions(mod.graph)
+ self.run_pass('dce', mod.graph)
+ self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
+ FileCheck().check("aten::to_dense").check_next("return").run(mod.graph)
+ self.assertTrue(torch.allclose(sub_model(inp), mod(inp)))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
@skipIfNoTorchVision
@@ -1869,6 +1909,7 @@
torch.nn.Hardswish(),
torch.nn.Hardsigmoid(),
torch.nn.ReLU6(),
+ torch.nn.Tanh(),
torch.nn.Hardtanh(0., 6.),
torch.nn.Hardtanh(1., 100.),
torch.nn.Hardtanh(-100., -1.),
@@ -1882,6 +1923,7 @@
N, C, H, W, = 10, 3, 224, 224
inp = torch.randn(N, C, H, W)
self.run_pass("convert_frozen_ops_to_mkldnn", mod.graph)
+ FileCheck().check_count("aten::to_dense", 1, exactly=True).run(mod.graph)
self.assertTrue(torch.allclose(sub_model(inp), mod(inp)))
@unittest.skipIf(not torch._C.has_mkldnn, "MKL-DNN build is disabled")
diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
index 5ab1eb9..779093a 100644
--- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
+++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
@@ -181,7 +181,7 @@
auto k = node->kind();
if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
- k == prim::MKLDNNHardTanh) {
+ k == prim::MKLDNNHardTanh || k == aten::tanh) {
if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
continue;
}
@@ -844,11 +844,19 @@
case aten::sigmoid:
case aten::hardsigmoid:
case aten::hardswish:
+ case aten::tanh:
+ case aten::batch_norm:
// TODO: max_pool on mkldnn can be slower than in eager. Ideally, we'd
// only fuse it if we knew including max_pool led to fewer layout
// conversions. From initial testing, including it speeds up models.
case aten::max_pool2d:
case aten::max_pool3d:
+ case aten::avg_pool2d:
+ case aten::adaptive_avg_pool2d:
+ case aten::avg_pool3d:
+ // case aten::adaptive_max_pool2d: // returns a tuple, which breaks fusion
+ // case aten::adaptive_max_pool3d: // returns a tuple, which breaks fusion
+ // case aten::adaptive_avg_pool3d: // no ideep binding
return true;
}
@@ -992,6 +1000,7 @@
aten::sigmoid_,
aten::hardsigmoid_,
aten::hardtanh_,
+ aten::tanh_,
};
return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
});
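
For illustration only (not part of the diff), a sketch of the in-place case that motivates adding `aten::tanh_` to the functionalization set above: a frozen module applying tanh in place should be rewritten to the out-of-place form so the conv/tanh chain can still be converted. The module name `ConvTanhInplace` is made up, and the same assumed pass binding as in the sketch above is used.

```python
import torch

class ConvTanhInplace(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 64, 2, 2)

    def forward(self, x):
        out = self.conv(x)
        out.tanh_()  # scripted as aten::tanh_, now in the functionalization set
        return out

frozen = torch.jit.freeze(torch.jit.script(ConvTanhInplace().eval()))
torch._C._jit_pass_convert_frozen_ops_to_mkldnn(frozen.graph)
# Expect tanh_ to be rewritten to out-of-place aten::tanh and the conv/tanh
# pair to run under MKLDNN layout; inspect the graph to confirm.
print(frozen.graph)
```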