add hardtanh(0,6) to the set of MKLDNN fusible ops for mobilenetv2 (#56203)
Summary:
`nn.ReLU6` is implemented as `Hardtanh(0, 6)`, so MobileNetV2's activations show up as `aten::hardtanh` in scripted graphs and were not covered by the `relu6`-only MKLDNN fusion. This change replaces the fixed `prim::MKLDNNRelu6` op with a `prim::MKLDNNHardTanh` op that carries `min_val`/`max_val` attributes, so both `aten::relu6` and `aten::hardtanh` can be converted inside frozen MKLDNN subgraphs. A `hardtanh` node is only converted when `min_val <= 0 <= max_val`, which preserves the `pointwise_func(0) == 0` precondition required by `createUnaryOp`. The retired `prim::MKLDNNRelu6` / `prim::MKLDNNRelu6_` symbols are added to the backward-compatibility allow list, the symbol-replacement helper in `ir.cpp` now copies node attributes along with metadata, and `aten::adaptive_avg_pool2d` is dropped from the set of ops converted to MKLDNN.

TODO: post the numbers for mobilenetv2
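
As a rough usage sketch (not part of this change), the new path can be exercised by freezing a small conv + hardtanh model and running the frozen-graph optimizations. This assumes a CPU build with MKLDNN enabled and that `torch.jit.optimize_for_inference` is available; the model and shapes below are illustrative only:

```python
# Illustrative repro, not from this PR: exercise conv + hardtanh fusion on an
# MKLDNN-enabled CPU build of PyTorch.
import torch

model = torch.nn.Sequential(
    torch.nn.Conv2d(3, 16, kernel_size=3),
    # min_val <= 0 <= max_val, so this clamp is eligible for MKLDNN conversion
    torch.nn.Hardtanh(0., 6.),
).eval()

frozen = torch.jit.freeze(torch.jit.script(model))
# The frozen-graph passes rewrite the conv + hardtanh region into an MKLDNN
# subgraph; hardtanh should become prim::MKLDNNHardTanh carrying min_val /
# max_val attributes instead of the old fixed prim::MKLDNNRelu6.
optimized = torch.jit.optimize_for_inference(frozen)

x = torch.rand(1, 3, 224, 224)
assert torch.allclose(optimized(x), model(x))
```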
Pull Request resolved: https://github.com/pytorch/pytorch/pull/56203
Reviewed By: malfet
Differential Revision: D27917557
Pulled By: Krovatkin
fbshipit-source-id: acea0f933a7e8c7a036a494295f68222c46a36f7
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index a33ee84..4f64e17 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -33,7 +33,7 @@
_(prim, MKLDNNGroup) \
_(prim, MKLDNNHardSwish) \
_(prim, MKLDNNHardSigmoid) \
- _(prim, MKLDNNRelu6) \
+ _(prim, MKLDNNHardTanh) \
_(prim, Drop) \
_(prim, Eval) \
_(prim, Expand) /* onnx */ \
@@ -337,6 +337,7 @@
_(aten, hardswish) \
_(aten, hardswish_) \
_(aten, hardsigmoid_) \
+ _(aten, hardtanh_) \
FORALL_ATEN_BASE_SYMBOLS(_) \
_(onnx, Add) \
_(onnx, Concat) \
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 4387aac..2a2bab2 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -30,6 +30,8 @@
# Internal
("static", datetime.date(9999, 1, 1)),
("prim::ModuleDictIndex", datetime.date(9999, 1, 1)),
+ ("prim::MKLDNNRelu6", datetime.date(9999, 1, 1)),
+ ("prim::MKLDNNRelu6_", datetime.date(9999, 1, 1)),
# Internal, profiler-specific ops
("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)),
("profiler::_record_function_enter", datetime.date(9999, 1, 1)),
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index aadc8d9..846e599 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1866,14 +1866,17 @@
def test_conv_hardswish(self):
with set_default_dtype(torch.float):
activations = [
- torch.nn.Hardswish,
- torch.nn.Hardsigmoid,
- torch.nn.ReLU6,
+ torch.nn.Hardswish(),
+ torch.nn.Hardsigmoid(),
+ torch.nn.ReLU6(),
+ torch.nn.Hardtanh(0., 6.),
+ torch.nn.Hardtanh(1., 100.),
+ torch.nn.Hardtanh(-100., -1.),
]
model = torchvision.models.resnet18()
for activation in activations:
- sub_model = torch.nn.Sequential(model.conv1, activation())
+ sub_model = torch.nn.Sequential(model.conv1, activation)
sub_model.eval()
mod = torch.jit.freeze(torch.jit.script(sub_model))
N, C, H, W, = 10, 3, 224, 224
@@ -1887,7 +1890,6 @@
op_map = {
'prim::MKLDNNHardSwish' : F.hardswish,
'prim::MKLDNNHardSigmoid' : F.hardsigmoid,
- 'prim::MKLDNNRelu6' : F.relu6
}
input_sizes = ([0], [1], [3], [1, 3, 8, 8])
diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp
index 27b2b21..04851c5 100644
--- a/torch/csrc/jit/ir/ir.cpp
+++ b/torch/csrc/jit/ir/ir.cpp
@@ -1289,6 +1289,7 @@
v->replaceAllUsesWith(new_out);
}
replace_node->copyMetadata(this);
+ replace_node->copyAttributes(*this);
TORCH_INTERNAL_ASSERT(
(replace_node->maybeOperator() != nullptr) == had_operator,
"invalid symbol replacement:",
diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
index f5565dd..5ab1eb9 100644
--- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
+++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
@@ -25,6 +25,7 @@
#include <ATen/native/ConvUtils.h>
#include <algorithm>
#include <memory>
+#include <ATen/core/stack.h>
#include <c10/core/Layout.h>
#include <c10/util/StringUtil.h>
@@ -180,7 +181,7 @@
auto k = node->kind();
if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
- k == prim::MKLDNNRelu6) {
+ k == prim::MKLDNNHardTanh) {
if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
continue;
}
@@ -349,6 +350,15 @@
};
}
+static std::function<void(at::Tensor output, at::Tensor input)> hardtanh_helper(
+ const Node* n) {
+ auto min_val = n->f(attr::min_val);
+ auto max_val = n->f(attr::max_val);
+ return [min_val, max_val](at::Tensor output, at::Tensor input) {
+ at::cpu::hardtanh_out(output, input, min_val, max_val);
+ };
+}
+
// any op added to this registry needs to meet
// the precondition: `aten_op(0) == 0`
const RegisterOperators MKLDNNHardSwishOpReg({
@@ -369,12 +379,10 @@
true),
AliasAnalysisKind::FROM_SCHEMA),
torch::jit::Operator(
- "prim::MKLDNNRelu6_(Tensor(a!) self) -> Tensor(a!)",
- createUnaryOp(
- [](at::Tensor output, at::Tensor input) {
- at::cpu::hardtanh_out(output, input, 0.f, 6.f);
- },
- true),
+ "prim::MKLDNNHardTanh_(Tensor(a!) self) -> Tensor(a!)",
+ [](const Node* n) -> Operation {
+ return createUnaryOp(hardtanh_helper(n), true);
+ },
AliasAnalysisKind::FROM_SCHEMA),
torch::jit::Operator(
"prim::MKLDNNHardSwish(Tensor a) -> Tensor",
@@ -393,12 +401,10 @@
false),
AliasAnalysisKind::FROM_SCHEMA),
torch::jit::Operator(
- "prim::MKLDNNRelu6(Tensor(a!) self) -> Tensor(a!)",
- createUnaryOp(
- [](at::Tensor output, at::Tensor input) {
- at::cpu::hardtanh_out(output, input, 0.f, 6.f);
- },
- false),
+ "prim::MKLDNNHardTanh(Tensor self) -> Tensor",
+ [](const Node* n) -> Operation {
+ return createUnaryOp(hardtanh_helper(n), false);
+ },
AliasAnalysisKind::FROM_SCHEMA),
});
@@ -569,6 +575,25 @@
}
}
+static void hardtanh_node_creator(
+ Node* body_node,
+ double min_val,
+ double max_val) {
+ WithInsertPoint insert_guard{body_node};
+ auto out_node = body_node->owningGraph()->create(
+ {prim::MKLDNNHardTanh}, {body_node->input(0)}, 1);
+  // N.B. we can't use `insert` here: it calls `getOperation` (via
+  // `emitBuiltinCall`), which reads the `min_val` and `max_val` attributes
+  // that we haven't set yet.
+ body_node->owningGraph()->insertNode(out_node);
+ auto out_val = out_node->output();
+ out_node->f_(attr::min_val, min_val);
+ out_node->f_(attr::max_val, max_val);
+ out_val->copyMetadata(body_node->output());
+ body_node->output()->replaceAllUsesWith(out_val);
+ body_node->destroy();
+}
+
void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
auto graph = subgraph_node->owningGraph();
Value* none_value = nullptr;
@@ -633,8 +658,16 @@
}
if (body_node->kind() == aten::relu6) {
- body_node->replaceWithNewSymbol(prim::MKLDNNRelu6);
- body_node->destroy();
+      hardtanh_node_creator(body_node, 0., 6.);
+ continue;
+ }
+
+ if (body_node->kind() == aten::hardtanh) {
+ auto min_val =
+ constant_as<double>(body_node->namedInput("min_val")).value();
+ auto max_val =
+ constant_as<double>(body_node->namedInput("max_val")).value();
+      hardtanh_node_creator(body_node, min_val, max_val);
continue;
}
@@ -816,10 +849,19 @@
// conversions. from initial testing including it speeds up models
case aten::max_pool2d:
case aten::max_pool3d:
- case aten::adaptive_avg_pool2d:
return true;
}
+ if (n->kind() == aten::hardtanh && !nonConstantParameters(n)) {
+ auto min_val = constant_as<double>(n->namedInput("min_val")).value();
+ auto max_val = constant_as<double>(n->namedInput("max_val")).value();
+ // we need to maintain the following invariant `pointwise_func(0) == 0`,
+ // see `createUnaryOp`
+ if (min_val <= 0. && max_val >= 0.) {
+ return true;
+ }
+ }
+
if (n->kind() == aten::add || n->kind() == aten::mul) {
// mkldnn doesn't currently support Tensor-Scalar add
for (size_t i = 0; i < 2; i++) {
@@ -949,6 +991,7 @@
aten::dropout_,
aten::sigmoid_,
aten::hardsigmoid_,
+ aten::hardtanh_,
};
return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
});