add hardtanh(0,6) to the set of MKLDNN fusible ops for mobilenetv2 (#56203)

Summary:
Replaces the fixed `prim::MKLDNNRelu6` op with a parameterized `prim::MKLDNNHardTanh` op that carries `min_val`/`max_val` attributes, so that hardtanh activations whose range contains zero (in particular hardtanh(0, 6), i.e. ReLU6 as used in MobileNetV2) can be fused into MKLDNN subgraphs during freezing.

TODO: post the numbers for mobilenetv2
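
The change is exercised through freezing followed by the FrozenOpsToMKLDNN pass. Below is a minimal sketch of that flow (not part of this PR; modeled on the updated test in test/jit/test_freezing.py, and assuming the `torch._C._jit_pass_convert_frozen_ops_to_mkldnn` binding and an MKLDNN-enabled build):

```python
import torch
from torch.testing import FileCheck

class ConvHardtanh(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.conv = torch.nn.Conv2d(3, 16, kernel_size=3)
        # hardtanh(0, 6) is ReLU6; any [min, max] range containing 0 is fusible
        self.act = torch.nn.Hardtanh(0., 6.)

    def forward(self, x):
        return self.act(self.conv(x))

eager = ConvHardtanh().eval()
mod = torch.jit.freeze(torch.jit.script(eager))
# assumed Python binding for the FrozenOpsToMKLDNN pass touched by this PR;
# requires a build with MKLDNN support, otherwise the graph is left unchanged
torch._C._jit_pass_convert_frozen_ops_to_mkldnn(mod.graph)
# after the pass the activation should show up as prim::MKLDNNHardTanh
FileCheck().check("prim::MKLDNNHardTanh").run(mod.graph)

x = torch.rand(1, 3, 32, 32)
assert torch.allclose(mod(x), eager(x))
```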

Pull Request resolved: https://github.com/pytorch/pytorch/pull/56203

Reviewed By: malfet

Differential Revision: D27917557

Pulled By: Krovatkin

fbshipit-source-id: acea0f933a7e8c7a036a494295f68222c46a36f7
diff --git a/aten/src/ATen/core/interned_strings.h b/aten/src/ATen/core/interned_strings.h
index a33ee84..4f64e17 100644
--- a/aten/src/ATen/core/interned_strings.h
+++ b/aten/src/ATen/core/interned_strings.h
@@ -33,7 +33,7 @@
   _(prim, MKLDNNGroup)               \
   _(prim, MKLDNNHardSwish)           \
   _(prim, MKLDNNHardSigmoid)         \
-  _(prim, MKLDNNRelu6)               \
+  _(prim, MKLDNNHardTanh)            \
   _(prim, Drop)                      \
   _(prim, Eval)                      \
   _(prim, Expand) /* onnx */         \
@@ -337,6 +337,7 @@
   _(aten, hardswish)                 \
   _(aten, hardswish_)                \
   _(aten, hardsigmoid_)              \
+  _(aten, hardtanh_)                 \
   FORALL_ATEN_BASE_SYMBOLS(_)        \
   _(onnx, Add)                       \
   _(onnx, Concat)                    \
diff --git a/test/backward_compatibility/check_backward_compatibility.py b/test/backward_compatibility/check_backward_compatibility.py
index 4387aac..2a2bab2 100644
--- a/test/backward_compatibility/check_backward_compatibility.py
+++ b/test/backward_compatibility/check_backward_compatibility.py
@@ -30,6 +30,8 @@
     # Internal
     ("static", datetime.date(9999, 1, 1)),
     ("prim::ModuleDictIndex", datetime.date(9999, 1, 1)),
+    ("prim::MKLDNNRelu6", datetime.date(9999, 1, 1)),
+    ("prim::MKLDNNRelu6_", datetime.date(9999, 1, 1)),
     # Internal, profiler-specific ops
     ("profiler::_call_end_callbacks_on_jit_fut*", datetime.date(9999, 1, 1)),
     ("profiler::_record_function_enter", datetime.date(9999, 1, 1)),
diff --git a/test/jit/test_freezing.py b/test/jit/test_freezing.py
index aadc8d9..846e599 100644
--- a/test/jit/test_freezing.py
+++ b/test/jit/test_freezing.py
@@ -1866,14 +1866,17 @@
     def test_conv_hardswish(self):
         with set_default_dtype(torch.float):
             activations = [
-                torch.nn.Hardswish,
-                torch.nn.Hardsigmoid,
-                torch.nn.ReLU6,
+                torch.nn.Hardswish(),
+                torch.nn.Hardsigmoid(),
+                torch.nn.ReLU6(),
+                torch.nn.Hardtanh(0., 6.),
+                torch.nn.Hardtanh(1., 100.),
+                torch.nn.Hardtanh(-100., -1.),
             ]
 
             model = torchvision.models.resnet18()
             for activation in activations:
-                sub_model = torch.nn.Sequential(model.conv1, activation())
+                sub_model = torch.nn.Sequential(model.conv1, activation)
                 sub_model.eval()
                 mod = torch.jit.freeze(torch.jit.script(sub_model))
                 N, C, H, W, = 10, 3, 224, 224
@@ -1887,7 +1890,6 @@
             op_map = {
                 'prim::MKLDNNHardSwish' : F.hardswish,
                 'prim::MKLDNNHardSigmoid' : F.hardsigmoid,
-                'prim::MKLDNNRelu6' : F.relu6
             }
 
             input_sizes = ([0], [1], [3], [1, 3, 8, 8])
diff --git a/torch/csrc/jit/ir/ir.cpp b/torch/csrc/jit/ir/ir.cpp
index 27b2b21..04851c5 100644
--- a/torch/csrc/jit/ir/ir.cpp
+++ b/torch/csrc/jit/ir/ir.cpp
@@ -1289,6 +1289,7 @@
     v->replaceAllUsesWith(new_out);
   }
   replace_node->copyMetadata(this);
+  replace_node->copyAttributes(*this);
   TORCH_INTERNAL_ASSERT(
       (replace_node->maybeOperator() != nullptr) == had_operator,
       "invalid symbol replacement:",
diff --git a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
index f5565dd..5ab1eb9 100644
--- a/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
+++ b/torch/csrc/jit/passes/frozen_ops_to_mkldnn.cpp
@@ -25,6 +25,7 @@
 #include <ATen/native/ConvUtils.h>
 #include <algorithm>
 #include <memory>
+#include <ATen/core/stack.h>
 #include <c10/core/Layout.h>
 #include <c10/util/StringUtil.h>
 
@@ -180,7 +181,7 @@
     auto k = node->kind();
     if (k == aten::relu || k == aten::sigmoid || k == aten::dropout ||
         k == prim::MKLDNNHardSwish || k == prim::MKLDNNHardSigmoid ||
-        k == prim::MKLDNNRelu6) {
+        k == prim::MKLDNNHardTanh) {
       if (set_liveness[alias_mapping[node->inputs().at(0)]]->isAfter(node)) {
         continue;
       }
@@ -349,6 +350,15 @@
   };
 }
 
+static std::function<void(at::Tensor output, at::Tensor input)> hardtanh_helper(
+    const Node* n) {
+  auto min_val = n->f(attr::min_val);
+  auto max_val = n->f(attr::max_val);
+  return [min_val, max_val](at::Tensor output, at::Tensor input) {
+    at::cpu::hardtanh_out(output, input, min_val, max_val);
+  };
+}
+
 // any op added to this registry needs to meet
 // the precondition: `aten_op(0) == 0`
 const RegisterOperators MKLDNNHardSwishOpReg({
@@ -369,12 +379,10 @@
             true),
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
-        "prim::MKLDNNRelu6_(Tensor(a!) self) -> Tensor(a!)",
-        createUnaryOp(
-            [](at::Tensor output, at::Tensor input) {
-              at::cpu::hardtanh_out(output, input, 0.f, 6.f);
-            },
-            true),
+        "prim::MKLDNNHardTanh_(Tensor(a!) self) -> Tensor(a!)",
+        [](const Node* n) -> Operation {
+          return createUnaryOp(hardtanh_helper(n), true);
+        },
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
         "prim::MKLDNNHardSwish(Tensor a) -> Tensor",
@@ -393,12 +401,10 @@
             false),
         AliasAnalysisKind::FROM_SCHEMA),
     torch::jit::Operator(
-        "prim::MKLDNNRelu6(Tensor(a!) self) -> Tensor(a!)",
-        createUnaryOp(
-            [](at::Tensor output, at::Tensor input) {
-              at::cpu::hardtanh_out(output, input, 0.f, 6.f);
-            },
-            false),
+        "prim::MKLDNNHardTanh(Tensor self) -> Tensor",
+        [](const Node* n) -> Operation {
+          return createUnaryOp(hardtanh_helper(n), false);
+        },
         AliasAnalysisKind::FROM_SCHEMA),
 });
 
@@ -569,6 +575,25 @@
   }
 }
 
+static void hardtanh_node_creator(
+    Node* body_node,
+    double min_val,
+    double max_val) {
+  WithInsertPoint insert_guard{body_node};
+  auto out_node = body_node->owningGraph()->create(
+      {prim::MKLDNNHardTanh}, {body_node->input(0)}, 1);
+  // N.B. we can't use `insert` here, as it calls `getOperation` (via
+  // `emitBuiltinCall`), which reads the `min_val` and `max_val` attrs that
+  // we haven't set yet.
+  body_node->owningGraph()->insertNode(out_node);
+  auto out_val = out_node->output();
+  out_node->f_(attr::min_val, min_val);
+  out_node->f_(attr::max_val, max_val);
+  out_val->copyMetadata(body_node->output());
+  body_node->output()->replaceAllUsesWith(out_val);
+  body_node->destroy();
+}
+
 void ComputeSubgraphInMKLDNN(Node* subgraph_node) {
   auto graph = subgraph_node->owningGraph();
   Value* none_value = nullptr;
@@ -633,8 +658,16 @@
     }
 
     if (body_node->kind() == aten::relu6) {
-      body_node->replaceWithNewSymbol(prim::MKLDNNRelu6);
-      body_node->destroy();
+      hardtanh_node_creator(body_node, 0., 6.);
+      continue;
+    }
+
+    if (body_node->kind() == aten::hardtanh) {
+      auto min_val =
+          constant_as<double>(body_node->namedInput("min_val")).value();
+      auto max_val =
+          constant_as<double>(body_node->namedInput("max_val")).value();
+      hardtanh_node_creator(body_node, min_val, max_val);
       continue;
     }
 
@@ -816,10 +849,19 @@
       // conversions. from initial testing including it speeds up models
       case aten::max_pool2d:
       case aten::max_pool3d:
-      case aten::adaptive_avg_pool2d:
         return true;
     }
 
+    if (n->kind() == aten::hardtanh && !nonConstantParameters(n)) {
+      auto min_val = constant_as<double>(n->namedInput("min_val")).value();
+      auto max_val = constant_as<double>(n->namedInput("max_val")).value();
+      // we need to maintain the following invariant `pointwise_func(0) == 0`,
+      // see `createUnaryOp`
+      if (min_val <= 0. && max_val >= 0.) {
+        return true;
+      }
+    }
+
     if (n->kind() == aten::add || n->kind() == aten::mul) {
       // mkldnn doesn't currently support Tensor-Scalar add
       for (size_t i = 0; i < 2; i++) {
@@ -949,6 +991,7 @@
           aten::dropout_,
           aten::sigmoid_,
           aten::hardsigmoid_,
+          aten::hardtanh_,
       };
       return mkldnn_ops.count(node_to_functionalize->kind()) != 0;
     });